forked from RUCAIBox/RecBole
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsgl.py
281 lines (232 loc) · 11.7 KB
/
sgl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# -*- coding: utf-8 -*-
# @Time : 2021/10/12
# @Author : Tian Zhen
# @Email : [email protected]
r"""
SGL
################################################
Reference:
Jiancan Wu et al. "SGL: Self-supervised Graph Learning for Recommendation" in SIGIR 2021.
Reference code:
https://github.com/wujcan/SGL
"""
import numpy as np
import scipy.sparse as sp
import torch
from recbole.model.abstract_recommender import GeneralRecommender
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType
import torch.nn.functional as F
class SGL(GeneralRecommender):
    r"""SGL is a GCN-based recommender model.

    SGL supplements the classical supervised task of recommendation with an auxiliary
    self-supervised task, which reinforces node representation learning via
    self-discrimination. Specifically, SGL generates multiple views of a node and
    maximizes the agreement between different views of the same node compared to
    that of other nodes. SGL devises three operators to generate the views — node
    dropout (``"ND"``), edge dropout (``"ED"``), and random walk (``"RW"``) — that
    change the graph structure in different manners.

    We implement the model following the original author with a pairwise training mode.
    """

    input_type = InputType.PAIRWISE

    # Graph-augmentation operators accepted for ``config["type"]``.
    _AUG_TYPES = ("ND", "ED", "RW")

    def __init__(self, config, dataset):
        super(SGL, self).__init__(config, dataset)

        # Per-interaction user / item id columns of the training data. They are used
        # directly as scipy sparse indices below (presumably 1-D integer tensors or
        # arrays of equal length — TODO confirm against the dataset implementation).
        self._user = dataset.inter_feat[dataset.uid_field]
        self._item = dataset.inter_feat[dataset.iid_field]

        self.embed_dim = config["embedding_size"]
        self.n_layers = int(config["n_layers"])
        self.type = config["type"]
        if self.type not in self._AUG_TYPES:
            # Fail fast: an unknown type used to silently produce empty augmented
            # views, degrading the self-supervised task without any error.
            raise ValueError(
                f"SGL: unsupported augmentation type {self.type!r}; "
                f"expected one of {self._AUG_TYPES}."
            )
        self.drop_ratio = config["drop_ratio"]
        self.ssl_tau = config["ssl_tau"]
        self.reg_weight = config["reg_weight"]
        self.ssl_weight = config["ssl_weight"]

        self.user_embedding = torch.nn.Embedding(self.n_users, self.embed_dim)
        self.item_embedding = torch.nn.Embedding(self.n_items, self.embed_dim)
        self.reg_loss = EmbLoss()

        # Normalized adjacency of the full (un-augmented) training graph.
        self.train_graph = self.csr2tensor(self.create_adjust_matrix(is_sub=False))
        # Two augmented views; (re)sampled by ``graph_construction`` whenever the
        # model enters training mode (see ``train``).
        self.sub_graph1 = None
        self.sub_graph2 = None

        # Cached full-graph embeddings used by ``predict`` / ``full_sort_predict``.
        self.restore_user_e = None
        self.restore_item_e = None

        self.apply(xavier_uniform_initialization)
        self.other_parameter_name = ["restore_user_e", "restore_item_e"]

    def graph_construction(self):
        r"""Sample two independent augmented views of the user-item graph.

        For ``"ND"`` and ``"ED"`` each view is one sparse adjacency tensor shared by
        all propagation layers; for ``"RW"`` each view is a list holding a freshly
        sampled adjacency tensor per layer. The results are stored in
        ``self.sub_graph1`` and ``self.sub_graph2``.
        """
        # View 1 is fully sampled before view 2, matching the original draw order.
        self.sub_graph1 = self._build_view()
        self.sub_graph2 = self._build_view()

    def _build_view(self):
        r"""Sample a single augmented view according to ``self.type``.

        Returns:
            torch.Tensor or list[torch.Tensor]: One sparse adjacency tensor for
            ``"ND"`` / ``"ED"``, or ``n_layers`` of them for ``"RW"``.

        Raises:
            ValueError: If ``self.type`` is not a supported operator.
        """
        if self.type in ("ND", "ED"):
            return self.csr2tensor(self.create_adjust_matrix(is_sub=True))
        if self.type == "RW":
            return [
                self.csr2tensor(self.create_adjust_matrix(is_sub=True))
                for _ in range(self.n_layers)
            ]
        raise ValueError(
            f"SGL: unsupported augmentation type {self.type!r}; "
            f"expected one of {self._AUG_TYPES}."
        )

    def rand_sample(self, high, size=None, replace=True):
        r"""Sample indices uniformly from ``[0, high)``.

        Used to choose the nodes to drop (ND) or the edges to keep (ED/RW).

        Args:
            high (int): Exclusive upper bound of the sampled index values.
            size (int, optional): Number of indices to draw; ``None`` draws one.
            replace (bool): Whether to sample with replacement.

        Returns:
            numpy.ndarray: Sampled indices, shape ``[size]`` (a scalar if ``size``
            is ``None``).
        """
        # ``np.random.choice(int)`` samples as if from ``np.arange(high)``.
        return np.random.choice(high, size=size, replace=replace)

    def create_adjust_matrix(self, is_sub: bool):
        r"""Get the normalized adjacency matrix of the user-item bipartite graph.

        A square ``(n_users + n_items)`` adjacency matrix is built from the training
        interactions, with item ids offset by ``n_users``. It is symmetrized and then
        normalized by node degree. If ``is_sub`` is True, the interactions are first
        perturbed by node dropout (``"ND"``) or edge dropout (``"ED"`` / ``"RW"``).

        .. math::
            \hat{A} = D^{-0.5} \times A \times D^{-0.5}

        Args:
            is_sub (bool): Build a randomly perturbed subgraph instead of the full
                training graph.

        Returns:
            scipy.sparse.csr_matrix: The normalized adjacency matrix.

        Raises:
            ValueError: If ``is_sub`` is True and ``self.type`` is not supported.
        """
        n_nodes = self.n_users + self.n_items
        if not is_sub:
            ratings = np.ones_like(self._user, dtype=np.float32)
            matrix = sp.csr_matrix(
                (ratings, (self._user, self._item + self.n_users)),
                shape=(n_nodes, n_nodes),
            )
        elif self.type == "ND":
            # Remove every interaction that touches a dropped user or dropped item.
            drop_user = self.rand_sample(
                self.n_users, size=int(self.n_users * self.drop_ratio), replace=False
            )
            drop_item = self.rand_sample(
                self.n_items, size=int(self.n_items * self.drop_ratio), replace=False
            )
            keep_user = np.ones(self.n_users, dtype=np.float32)
            keep_user[drop_user] = 0.0
            keep_item = np.ones(self.n_items, dtype=np.float32)
            keep_item[drop_item] = 0.0
            inter_mat = sp.csr_matrix(
                (np.ones_like(self._user, dtype=np.float32), (self._user, self._item)),
                shape=(self.n_users, self.n_items),
            )
            masked = sp.diags(keep_user).dot(inter_mat).dot(sp.diags(keep_item)).tocoo()
            # Read row/col/data from a single COO view so they are always aligned
            # (``nonzero()`` + ``.data`` diverge if explicit zeros are stored).
            matrix = sp.csr_matrix(
                (masked.data, (masked.row, masked.col + self.n_users)),
                shape=(n_nodes, n_nodes),
            )
        elif self.type in ("ED", "RW"):
            # Keep a random (1 - drop_ratio) fraction of the interactions (edges).
            n_inter = len(self._user)
            keep_idx = self.rand_sample(
                n_inter, size=int(n_inter * (1 - self.drop_ratio)), replace=False
            )
            user = self._user[keep_idx]
            item = self._item[keep_idx]
            matrix = sp.csr_matrix(
                (np.ones_like(user, dtype=np.float32), (user, item + self.n_users)),
                shape=(n_nodes, n_nodes),
            )
        else:
            raise ValueError(
                f"SGL: unsupported augmentation type {self.type!r}; "
                f"expected one of {self._AUG_TYPES}."
            )
        # Mirror the user->item block into item->user so the adjacency is symmetric.
        matrix = matrix + matrix.T
        # The small epsilon avoids dividing by zero for isolated (e.g. dropped) nodes.
        degree = np.array(matrix.sum(axis=1)) + 1e-7
        d_inv_sqrt = sp.diags(np.power(degree, -0.5).flatten())
        return d_inv_sqrt.dot(matrix).dot(d_inv_sqrt)

    def csr2tensor(self, matrix: sp.csr_matrix):
        r"""Convert a scipy sparse matrix to a torch sparse COO tensor on ``self.device``.

        Args:
            matrix (scipy.sparse.csr_matrix): Sparse matrix to be converted.

        Returns:
            torch.Tensor: Sparse COO float32 tensor with the same shape and entries.
        """
        coo = matrix.tocoo()
        indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
        values = torch.from_numpy(coo.data.astype(np.float32))
        # ``torch.sparse.FloatTensor`` is deprecated; ``sparse_coo_tensor`` is the
        # supported constructor.
        return torch.sparse_coo_tensor(indices, values, coo.shape).to(self.device)

    def forward(self, graph):
        r"""Propagate the embeddings over ``graph`` and average all layer outputs.

        Args:
            graph (torch.Tensor or list[torch.Tensor]): A single sparse adjacency used
                for all ``n_layers`` layers, or a list with one adjacency per layer.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: User and item embeddings, i.e. the mean
            of the input embeddings and every layer's output.
        """
        ego = torch.cat([self.user_embedding.weight, self.item_embedding.weight])
        layers = graph if isinstance(graph, list) else [graph] * self.n_layers
        all_ego = [ego]
        for adj in layers:
            ego = torch.sparse.mm(adj, ego)
            all_ego.append(ego)
        mean_ego = torch.stack(all_ego, dim=1).mean(dim=1)
        user_emd, item_emd = torch.split(mean_ego, [self.n_users, self.n_items], dim=0)
        return user_emd, item_emd

    def calculate_loss(self, interaction):
        r"""Compute the BPR + regularization + self-supervised loss for one batch."""
        # Training invalidates the cached evaluation embeddings.
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None
        # Views are normally sampled in ``train()``; build them lazily if that has
        # not happened yet instead of failing on a missing attribute.
        if self.sub_graph1 is None or self.sub_graph2 is None:
            self.graph_construction()

        user_list = interaction[self.USER_ID]
        pos_item_list = interaction[self.ITEM_ID]
        neg_item_list = interaction[self.NEG_ITEM_ID]

        user_emd, item_emd = self.forward(self.train_graph)
        user_sub1, item_sub1 = self.forward(self.sub_graph1)
        user_sub2, item_sub2 = self.forward(self.sub_graph2)

        bpr_loss = self.calc_bpr_loss(
            user_emd, item_emd, user_list, pos_item_list, neg_item_list
        )
        ssl_loss = self.calc_ssl_loss(
            user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2
        )
        return bpr_loss + ssl_loss

    def calc_bpr_loss(self, user_emd, item_emd, user_list, pos_item_list, neg_item_list):
        r"""Calculate the pairwise BPR loss plus the embedding regularization loss.

        Args:
            user_emd (torch.Tensor): Embeddings of all users after propagation.
            item_emd (torch.Tensor): Embeddings of all items after propagation.
            user_list (torch.Tensor): User ids of the batch.
            pos_item_list (torch.Tensor): Positive item ids of the batch.
            neg_item_list (torch.Tensor): Negative item ids of the batch.

        Returns:
            torch.Tensor: Summed BPR loss plus ``reg_weight`` times the regularization
            loss of the raw (layer-0) embeddings.
        """
        u_e = user_emd[user_list]
        pi_e = item_emd[pos_item_list]
        ni_e = item_emd[neg_item_list]
        p_scores = (u_e * pi_e).sum(dim=1)
        n_scores = (u_e * ni_e).sum(dim=1)
        bpr = torch.sum(-F.logsigmoid(p_scores - n_scores))

        # Regularize the trainable input embeddings, not the propagated ones.
        reg = self.reg_loss(
            self.user_embedding(user_list),
            self.item_embedding(pos_item_list),
            self.item_embedding(neg_item_list),
        )
        return bpr + reg * self.reg_weight

    def calc_ssl_loss(self, user_list, pos_item_list, user_sub1, user_sub2, item_sub1, item_sub2):
        r"""Calculate the self-supervised (InfoNCE) loss between the two views.

        Args:
            user_list (torch.Tensor): User ids of the batch.
            pos_item_list (torch.Tensor): Positive item ids of the batch.
            user_sub1 (torch.Tensor): All user embeddings from the first view.
            user_sub2 (torch.Tensor): All user embeddings from the second view.
            item_sub1 (torch.Tensor): All item embeddings from the first view.
            item_sub2 (torch.Tensor): All item embeddings from the second view.

        Returns:
            torch.Tensor: ``ssl_weight`` times the sum of the user and item losses.
        """
        ssl_user = self._info_nce(user_sub1, user_sub2, user_list)
        ssl_item = self._info_nce(item_sub1, item_sub2, pos_item_list)
        return (ssl_item + ssl_user) * self.ssl_weight

    def _info_nce(self, view1, view2, index):
        r"""Summed InfoNCE loss of the rows ``index`` of ``view1`` against all of ``view2``.

        The positive pair is the same node in both views; the negatives are every
        row of ``view2``. Cosine similarities are scaled by ``1 / ssl_tau``.
        """
        z1 = F.normalize(view1[index], dim=1)
        z2 = F.normalize(view2[index], dim=1)
        all_z2 = F.normalize(view2, dim=1)
        pos = torch.sum(z1 * z2, dim=1) / self.ssl_tau
        logits = z1.matmul(all_z2.T) / self.ssl_tau
        # -log(exp(pos) / sum(exp(logits))) computed stably via logsumexp, which
        # avoids exp() overflow for small temperatures.
        return -torch.sum(pos - torch.logsumexp(logits, dim=1))

    def predict(self, interaction):
        r"""Score the (user, item) pairs of ``interaction`` using full-graph embeddings."""
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward(self.train_graph)
        user = self.restore_user_e[interaction[self.USER_ID]]
        item = self.restore_item_e[interaction[self.ITEM_ID]]
        return torch.sum(user * item, dim=1)

    def full_sort_predict(self, interaction):
        r"""Score every item for each user in ``interaction``.

        Returns:
            torch.Tensor: Score matrix of shape ``[batch_users, n_items]``.
        """
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward(self.train_graph)
        user = self.restore_user_e[interaction[self.USER_ID]]
        return user.matmul(self.restore_item_e.T)

    def train(self, mode: bool = True):
        r"""Set train/eval mode; entering training mode resamples both augmented views.

        Overrides ``torch.nn.Module.train`` so that fresh subgraphs are drawn every
        time the model is put into training mode.
        """
        T = super().train(mode=mode)
        if mode:
            self.graph_construction()
        return T