Skip to content

Commit 86b20cd

Browse files
authored
Merge pull request #1225 from Sherry-XLL/master
FIX: fix UserWarning in get_norm_adj_mat and accelerate csr2tensor
2 parents 4f76169 + 390c305 commit 86b20cd

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

recbole/model/general_recommender/lightgcn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def get_norm_adj_mat(self):
100100
L = sp.coo_matrix(L)
101101
row = L.row
102102
col = L.col
103-
i = torch.LongTensor([row, col])
103+
i = torch.LongTensor(np.array([row, col]))
104104
data = torch.FloatTensor(L.data)
105105
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
106106
return SparseL

recbole/model/general_recommender/ncl.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_norm_adj_mat(self):
121121
L = sp.coo_matrix(L)
122122
row = L.row
123123
col = L.col
124-
i = torch.LongTensor([row, col])
124+
i = torch.LongTensor(np.array([row, col]))
125125
data = torch.FloatTensor(L.data)
126126
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
127127
return SparseL

recbole/model/general_recommender/ngcf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def get_norm_adj_mat(self):
103103
L = sp.coo_matrix(L)
104104
row = L.row
105105
col = L.col
106-
i = torch.LongTensor([row, col])
106+
i = torch.LongTensor(np.array([row, col]))
107107
data = torch.FloatTensor(L.data)
108108
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
109109
return SparseL

recbole/model/general_recommender/sgl.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def csr2tensor(self, matrix: sp.csr_matrix):
156156
"""
157157
matrix = matrix.tocoo()
158158
x = torch.sparse.FloatTensor(
159-
torch.LongTensor([matrix.row.tolist(), matrix.col.tolist()]),
159+
torch.LongTensor(np.array([matrix.row, matrix.col])),
160160
torch.FloatTensor(matrix.data.astype(np.float32)), matrix.shape
161161
).to(self.device)
162162
return x

0 commit comments

Comments
 (0)