Skip to content

Commit 7d828d2

Browse files
committed
updated
updated
1 parent 57cacef commit 7d828d2

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/knowledgegraph_pytorch.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, lang, kg_train_data, kg_val_data, kg_test_data, num_entity, n
4343
self.true_tail = self.get_true_tail(torch.cat([self.train_data, self.val_data], dim=0)) # [numpy]
4444

4545
# Rewrite to torch form, TODO: r here needs to be global indexing.
46-
self.h_train, self.r_train, self.t_train = self.train_data[:, 0], self.train_data[:, 1]+self.relation_id_base, self.train_data[:, 2]
47-
self.h_val, self.r_val, self.t_val = self.val_data[:, 0], self.val_data[:, 1]+self.relation_id_base, self.val_data[:, 2]
48-
self.h_test, self.r_test, self.t_test = self.test_data[:, 0], self.test_data[:, 1]+self.relation_id_base, self.test_data[:, 2]
46+
self.h_train, self.r_train, self.t_train = self.train_data[:, 0], self.train_data[:, 1], self.train_data[:, 2]
47+
self.h_val, self.r_val, self.t_val = self.val_data[:, 0], self.val_data[:, 1], self.val_data[:, 2]
48+
self.h_test, self.r_test, self.t_test = self.test_data[:, 0], self.test_data[:, 1], self.test_data[:, 2]
4949

5050

5151

0 commit comments

Comments
 (0)