-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
smart dep_manager #578
base: master
Are you sure you want to change the base?
smart dep_manager #578
Changes from 250 commits
42f6a75
aadc84b
edbdf37
0ec47bb
139f7f9
3223a27
c47df98
26cd5e9
e47fa35
d8f9e6d
1b9f32e
2751aa9
9896f82
64153ab
d86ef4e
c370598
a3ea5d0
30ca9ee
d0997b4
836a9f4
6646b73
cf74443
da72b63
b062d59
e0e401e
2a0a9af
05a1329
760687e
ad2c703
454331a
6e1cd20
8289d51
1573e1c
14edf7b
565e9ac
c97c204
4eb824f
070c576
1c73235
d53a306
1904df5
a9d3d9e
0dd4ed6
6007eb7
6d0cb1c
86378eb
5b36dd0
90ca97a
08de406
b236337
85e1e24
c86cb53
5d5146f
8640971
58d9810
d02d480
a39928c
cedd9ad
dcfdd9c
cc8c4d2
79045df
7bb1cc9
0e4b19d
f7e97df
0372b7c
3e7f0e0
3eff36e
fb9d37c
0ac9516
2326237
56a0e73
ca2e7bf
5fb1f28
8a6008a
9a364a7
1ad8e96
0f14d99
17beba0
5ec85fd
1386f0b
8bf48e5
69b5f3f
3bc04fa
1544927
495c031
8a41d10
26b4f94
707b404
93c4021
bef055e
bb4e67a
8241a1a
c8421ef
5a65b51
569d09f
1fb98c0
e62c8ab
773ba7d
7901010
f857d2f
ba28dd0
00b1e88
9250f44
916bf4c
f4b8ed8
ee181c2
67e4732
0504e2b
ee08701
974f800
5fe7b87
8793913
c40ad22
31e2a41
b00ab9b
462ae91
aede506
5ba3a83
ba4a398
19d7f46
b90bb8b
82d537e
58d463b
1546db1
0b7dc9f
7f72d09
f87982a
ce3f089
26f1621
0c046bb
ca7ab4a
52a1216
2d49231
f871e10
cc90767
5c66802
9103a7c
63ad9ae
b8c28aa
85f1a70
e7b8137
a771fd6
1d4723d
9db0dd4
68fd472
d10655a
5180b09
d1fe703
e9a0a68
c7a7676
9f095e1
8a67f27
23a2f91
283086c
b78fc6a
c7c715e
03f0fc3
50562b7
75ac2b0
1f13df5
0a8efbf
af5c1bc
30d64b2
167513d
66948c0
8ff98fa
5f02c49
95836a1
542416f
1cb9020
ff38bcc
4d493a9
4272d23
795e6d1
d2728e8
6cb44ab
e41ab49
ff94946
9c73438
581df16
993b9fe
a1b61ac
b57978e
2d8ca8b
e620baf
4e7c9b5
dcff47c
83f55f9
b193706
8e2999a
12914c4
db380ef
3ebea98
6faaa68
bb4b994
cb2b08a
d3684f5
f0b23f4
6d7df64
3462b97
01b2c1a
5db6194
26bc7e8
c9e2852
7d8ec84
110e361
e62c1e0
26f805e
60ec445
286b9cd
ff0a11e
75c8cfa
1e876fb
77066bb
3cef8fd
857fc26
746ef73
0d9ef7b
4afe581
3175279
fc34d3f
171bc9a
4fd9a37
e390d46
5112624
b5fcf36
e96917a
6852295
611d439
806c84c
ee0e61a
03f93dc
a0318e0
e7b2506
6f455d1
1d1274b
598b8b8
d7ccb2a
50569c4
91ba2f1
48dd42b
1897240
11a8996
5288729
378e556
303966a
6332fe1
84f20c3
9744908
31c2300
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,47 +17,26 @@ | |
) | ||
|
||
from .util import setup_logger | ||
|
||
from .utils.dep_manager import deps | ||
|
||
if TYPE_CHECKING: | ||
import scipy | ||
MIXIN_BASE = FeatureMixin | ||
try: | ||
import torch | ||
torch = deps.torch | ||
except: | ||
pass | ||
try: | ||
import dgl | ||
dgl = deps.dgl | ||
except: | ||
pass | ||
else: | ||
MIXIN_BASE = object | ||
|
||
|
||
def lazy_dgl_import_has_dependency(): | ||
try: | ||
import warnings | ||
warnings.filterwarnings('ignore') | ||
import dgl # noqa: F811 | ||
return True, 'ok', dgl | ||
except ModuleNotFoundError as e: | ||
return False, e, None | ||
|
||
|
||
def lazy_torch_import_has_dependency(): | ||
try: | ||
import warnings | ||
warnings.filterwarnings('ignore') | ||
import torch # noqa: F811 | ||
return True, 'ok', torch | ||
except ModuleNotFoundError as e: | ||
return False, e, None | ||
|
||
|
||
logger = setup_logger(name=__name__) | ||
|
||
|
||
|
||
# ######################################################################################### | ||
# | ||
# Torch helpers | ||
|
@@ -73,7 +52,7 @@ def convert_to_torch(X_enc: pd.DataFrame, y_enc: Optional[pd.DataFrame]): # typ | |
:param y_enc: DataFrame Matrix of Values for Target | ||
:return: Dictionary of torch encoded arrays | ||
""" | ||
_, _, torch = lazy_torch_import_has_dependency() # noqa: F811 | ||
torch = deps.torch # noqa: F811 | ||
|
||
if not y_enc.empty: # type: ignore | ||
data = { | ||
|
@@ -98,7 +77,7 @@ def get_available_devices(): | |
device (torch.device): Main device (GPU 0 or CPU). | ||
gpu_ids (list): List of IDs of all GPUs that are available. | ||
""" | ||
_, _, torch = lazy_torch_import_has_dependency() # noqa: F811 | ||
torch = deps.torch # noqa: F811 | ||
|
||
gpu_ids = [] | ||
if torch.cuda.is_available(): | ||
|
@@ -181,7 +160,8 @@ def pandas_to_dgl_graph( | |
sp_mat: sparse scipy matrix | ||
ordered_nodes_dict: dict ordered from most common src and dst nodes | ||
""" | ||
_, _, dgl = lazy_dgl_import_has_dependency() # noqa: F811 | ||
dgl = deps.dgl # noqa: F811 | ||
|
||
sp_mat, ordered_nodes_dict = pandas_to_sparse_adjacency(df, src, dst, weight_col) | ||
g = dgl.from_scipy(sp_mat, device=device) # there are other ways too | ||
logger.info(f"Graph Type: {type(g)}") | ||
|
@@ -196,7 +176,7 @@ def get_torch_train_test_mask(n: int, ratio: float = 0.8): | |
:param ratio: mimics train/test split. `ratio` sets number of True vs False mask entries. | ||
:return: train and test torch tensor masks | ||
""" | ||
_, _, torch = lazy_torch_import_has_dependency() # noqa: F811 | ||
torch = deps.torch # noqa: F811 | ||
|
||
train_mask = torch.zeros(n, dtype=torch.bool).bernoulli(ratio) | ||
test_mask = ~train_mask | ||
|
@@ -225,8 +205,8 @@ def dgl_lazy_init(self, train_split: float = 0.8, device: str = "cpu"): | |
""" | ||
|
||
if not self.dgl_initialized: | ||
lazy_dgl_import_has_dependency() | ||
lazy_torch_import_has_dependency() | ||
deps.dgl | ||
deps.torch | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is the point of these calls? Do we need to throw an exception if the dependency is missing? If so, can we access the original exception and rethrow it? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So long as TYPE_CHECKING is present, these will be imported above, so they are not necessary here. |
||
self.train_split = train_split | ||
self.device = device | ||
self._removed_edges_previously = False | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,43 +2,22 @@ | |
import numpy as np | ||
import pandas as pd | ||
from typing import Optional, Union, Callable, List, TYPE_CHECKING, Any, Tuple | ||
|
||
from inspect import getmodule | ||
from .PlotterBase import Plottable | ||
from .compute.ComputeMixin import ComputeMixin | ||
from .utils.dep_manager import deps | ||
|
||
|
||
def lazy_embed_import_dep(): | ||
try: | ||
import torch | ||
import torch.nn as nn | ||
import dgl | ||
from dgl.dataloading import GraphDataLoader | ||
import torch.nn.functional as F | ||
from .networks import HeteroEmbed | ||
from tqdm import trange | ||
return True, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange | ||
|
||
except: | ||
return False, None, None, None, None, None, None, None | ||
|
||
def check_cudf(): | ||
try: | ||
import cudf | ||
return True, cudf | ||
except: | ||
return False, object | ||
|
||
|
||
if TYPE_CHECKING: | ||
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep() | ||
torch = deps.torch | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ideally, these are typed. |
||
TT = torch.Tensor | ||
MIXIN_BASE = ComputeMixin | ||
else: | ||
TT = Any | ||
MIXIN_BASE = object | ||
torch = Any | ||
|
||
has_cudf, cudf = check_cudf() | ||
cudf = deps.cudf | ||
|
||
XSymbolic = Optional[Union[List[str], str, pd.DataFrame]] | ||
ProtoSymbolic = Optional[Union[str, Callable[[TT, TT, TT], TT]]] # type: ignore | ||
|
@@ -99,8 +78,7 @@ def __init__(self): | |
self._device = "cpu" | ||
|
||
def _preprocess_embedding_data(self, res, train_split:Union[float, int] = 0.8) -> Plottable: | ||
#_, torch, _, _, _, _, _, _ = lazy_embed_import_dep() | ||
import torch | ||
torch = deps.torch | ||
log('Preprocessing embedding data') | ||
src, dst = res._source, res._destination | ||
relation = res._relation | ||
|
@@ -147,7 +125,7 @@ def _preprocess_embedding_data(self, res, train_split:Union[float, int] = 0.8) - | |
return res | ||
|
||
def _build_graph(self, res) -> Plottable: | ||
_, _, _, dgl, _, _, _, _ = lazy_embed_import_dep() | ||
dgl = deps.dgl | ||
s, r, t = res._triplets.T | ||
|
||
if res._train_idx is not None: | ||
|
@@ -169,7 +147,10 @@ def _build_graph(self, res) -> Plottable: | |
|
||
|
||
def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, device): | ||
_, _, _, _, GraphDataLoader, HeteroEmbed, _, _ = lazy_embed_import_dep() | ||
dgl_ = deps.dgl | ||
if dgl_: | ||
from dgl.dataloading import GraphDataLoader | ||
from .networks import HeteroEmbed | ||
g_iter = SubgraphIterator(res._kg_dgl, sample_size, num_steps) | ||
g_dataloader = GraphDataLoader( | ||
g_iter, batch_size=batch_size, collate_fn=lambda x: x[0] | ||
|
@@ -186,9 +167,11 @@ def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, devic | |
) | ||
|
||
return model, g_dataloader | ||
|
||
def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_size:int, num_steps:int, device) -> Plottable: | ||
_, torch, nn, _, _, _, _, trange = lazy_embed_import_dep() | ||
torch = deps.torch | ||
nn = deps.torch.nn | ||
trange = deps.tqdm.trange | ||
log('Training embedding') | ||
model, g_dataloader = res._init_model(res, batch_size, sample_size, num_steps, device) | ||
if hasattr(res, "_embed_model") and not res._build_new_embedding_model: | ||
|
@@ -232,7 +215,7 @@ def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_siz | |
|
||
@property | ||
def _gcn_node_embeddings(self): | ||
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep() | ||
torch = deps.torch | ||
g_dgl = self._kg_dgl.to(self._device) | ||
em = self._embed_model(g_dgl).detach() | ||
torch.cuda.empty_cache() | ||
|
@@ -301,12 +284,12 @@ def embed( | |
""" | ||
# this is temporary, will be fixed in future releases | ||
try: | ||
if isinstance(self._nodes, cudf.DataFrame): | ||
if 'cudf' in str(getmodule(self._nodes)): | ||
self._nodes = self._nodes.to_pandas() | ||
except: | ||
pass | ||
try: | ||
if isinstance(self._edges, cudf.DataFrame): | ||
if 'cudf' in str(getmodule(self._edges)): | ||
self._edges = self._edges.to_pandas() | ||
except: | ||
pass | ||
|
@@ -436,7 +419,7 @@ def predict_links( | |
else: | ||
# this is temporary, will be removed after gpu feature utils | ||
try: | ||
if isinstance(source, cudf.DataFrame): | ||
if 'cudf' in str(getmodule(source)): | ||
source = source.to_pandas() # type: ignore | ||
except: | ||
pass | ||
|
@@ -448,7 +431,7 @@ def predict_links( | |
else: | ||
# this is temporary, will be removed after gpu feature utils | ||
try: | ||
if isinstance(relation, cudf.DataFrame): | ||
if 'cudf' in str(getmodule(relation)): | ||
relation = relation.to_pandas() # type: ignore | ||
except: | ||
pass | ||
|
@@ -460,7 +443,8 @@ def predict_links( | |
else: | ||
# this is temporary, will be removed after gpu feature utils | ||
try: | ||
if isinstance(destination, cudf.DataFrame): | ||
# if isinstance(destination, cudf.DataFrame): | ||
if 'cudf' in str(getmodule(destination)): | ||
destination = destination.to_pandas() # type: ignore | ||
except: | ||
pass | ||
|
@@ -540,7 +524,7 @@ def fetch_triplets_for_inference(x_r): | |
|
||
|
||
def _score(self, triplets: Union[np.ndarray, TT]) -> TT: # type: ignore | ||
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep() | ||
torch = deps.torch | ||
emb = self._kg_embeddings.clone().detach() | ||
if not isinstance(triplets, torch.Tensor): | ||
triplets = torch.tensor(triplets) | ||
|
@@ -571,7 +555,13 @@ def __len__(self) -> int: | |
return self.num_steps | ||
|
||
def __getitem__(self, i:int): | ||
_, torch, nn, dgl, GraphDataLoader, _, F, _ = lazy_embed_import_dep() | ||
torch = deps.torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
dgl = deps.dgl | ||
|
||
from dgl.dataloading import GraphDataLoader | ||
|
||
eids = torch.from_numpy(np.random.choice(self.eids, self.sample_size)) | ||
|
||
src, dst = self.g.find_edges(eids) | ||
|
@@ -593,7 +583,7 @@ def __getitem__(self, i:int): | |
|
||
@staticmethod | ||
def _sample_neg(triplets:np.ndarray, num_nodes:int) -> Tuple[TT, TT]: # type: ignore | ||
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep() | ||
torch = deps.torch | ||
triplets = torch.tensor(triplets) | ||
h, r, t = triplets.T | ||
h_o_t = torch.randint(high=2, size=h.size()) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
During type checking, we want the typed import