@@ -43,9 +43,9 @@ def __init__(self, lang, kg_train_data, kg_val_data, kg_test_data, num_entity, n
43
43
self .true_tail = self .get_true_tail (torch .cat ([self .train_data , self .val_data ], dim = 0 )) # [numpy]
44
44
45
45
# Rewrite to torch form, TODO: r here needs to be global indexing.
46
- self .h_train , self .r_train , self .t_train = self .train_data [:, 0 ], self .train_data [:, 1 ]+ self . relation_id_base , self .train_data [:, 2 ]
47
- self .h_val , self .r_val , self .t_val = self .val_data [:, 0 ], self .val_data [:, 1 ]+ self . relation_id_base , self .val_data [:, 2 ]
48
- self .h_test , self .r_test , self .t_test = self .test_data [:, 0 ], self .test_data [:, 1 ]+ self . relation_id_base , self .test_data [:, 2 ]
46
+ self .h_train , self .r_train , self .t_train = self .train_data [:, 0 ], self .train_data [:, 1 ], self .train_data [:, 2 ]
47
+ self .h_val , self .r_val , self .t_val = self .val_data [:, 0 ], self .val_data [:, 1 ], self .val_data [:, 2 ]
48
+ self .h_test , self .r_test , self .t_test = self .test_data [:, 0 ], self .test_data [:, 1 ], self .test_data [:, 2 ]
49
49
50
50
51
51
0 commit comments