Skip to content

Commit 976465e

Browse files
committed
updated
updated
1 parent d4483cb commit 976465e

8 files changed

+1952
-0
lines changed

run_model.py

+408
Large diffs are not rendered by default.

src/data_loader_new.py

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
from os.path import join
2+
import pandas as pd
3+
import numpy as np
4+
import os
5+
from src.knowledgegraph_pytorch import KnowledgeGraph
6+
import numpy as np
7+
from torch.utils.data import Dataset
8+
import torch
9+
from src.utils import get_language_list, get_subgraph_list, subgrarph_list_from_alignment
10+
import copy
11+
12+
13+
14+
class ParseData(object):
    """Loader for a multilingual knowledge-graph dataset.

    Reads, from ``<data_path>/<dataset>``:
      * ``entity_embeddings.npy``       — BERT entity-embedding matrix,
      * ``kg/<lang>-{train,val,test}.tsv`` — (head, relation, tail) triples per language,
      * ``entity/<lang>.tsv``           — entity vocabulary (one entity per line),
      * ``relations.txt``               — shared relation vocabulary,
      * ``seed_alignlinks/<l1>-<l2>.tsv`` — seed alignment links between language pairs,
    and builds one :class:`KnowledgeGraph` object per language plus the
    (optionally masked) alignment-seed tensors.

    NOTE: KG ordering everywhere follows ``get_language_list`` (sorted listdir).
    """

    def __init__(self, args):
        # Directory layout is fixed relative to the dataset root.
        self.data_path = args.data_path + args.dataset
        self.data_entity = self.data_path + "/entity/"
        self.data_kg = self.data_path + "/kg/"
        self.data_align = self.data_path + "/seed_alignlinks/"
        self.args = args

        self.target_kg = args.target_language
        self.kg_names = get_language_list(self.data_path)  # all kg names, sorted
        self.num_kgs = len(self.kg_names)

    def load_data(self):
        """Load everything needed for training.

        :return: tuple of
            1. kg_object_dict: {lang: KnowledgeGraph} (global index bases assigned),
            2. seeds_masked: alignment links withheld from the graph (for alignment loss),
            3. seeds_all: all alignment links,
            4. entity_bert_emb: embedding matrix normalized to [-1, 1].
        """
        entity_bert_emb = np.load(self.data_path + "/entity_embeddings.npy")
        # Normalize features to be within [-1, 1].
        entity_bert_emb = self.normalize_fature(entity_bert_emb)
        kg_object_dict, seeds_masked, seeds_all = self.create_KG_objects_and_alignment()

        # Every KG shares the same relation vocabulary (see load_kg_data), so the
        # global relation count is per-KG count times the number of KGs.
        self.num_relations = kg_object_dict[self.target_kg].num_relation * self.num_kgs

        return kg_object_dict, seeds_masked, seeds_all, entity_bert_emb

    def normalize_fature(self, input_embedding):
        # NOTE: method name keeps the historical typo ("fature") so existing
        # callers keep working.
        """Min-max normalize the whole matrix into [-1, 1].

        :param input_embedding: np.ndarray of arbitrary shape (constant input
            would divide by zero — assumed non-constant, TODO confirm upstream).
        :return: array of the same shape, scaled so min -> -1 and max -> 1.
        """
        input_min = input_embedding.min()
        input_max = input_embedding.max()
        return (input_embedding - input_min) * 2 / (input_max - input_min) - 1

    def load_all_to_all_seed_align_links(self):
        """Read every pairwise seed-alignment file and split it into
        preserved / masked subsets according to ``args.preserved_ratio``.

        :return: (seeds_masked, seeds_all, seeds_preserved), each a dict
            {(lang1, lang2): LongTensor[N, 2]}; masked/preserved entries are
            None when preserved_ratio == 1.0 (nothing is withheld).
        """
        seeds_preserved = {}  # links kept in the graph construction
        seeds_masked = {}     # links withheld, used as supervision targets
        seeds_all = {}
        for f in os.listdir(self.data_align):  # e.g. 'el-en.tsv'
            # Language pair is encoded in the file stem as '<l1>-<l2>'.
            lang1, lang2 = os.path.splitext(f)[0].split('-')
            links = pd.read_csv(join(self.data_align, f), sep='\t', header=None).values.astype(int)  # [N,2]

            total_link_num = links.shape[0]
            if self.args.preserved_ratio != 1.0:
                preserved_idx = np.sort(np.random.choice(
                    np.arange(total_link_num),
                    int(total_link_num * self.args.preserved_ratio),
                    replace=False))
                # Complement via boolean mask — O(n), unlike a per-element
                # membership test against a list.
                keep_mask = np.ones(total_link_num, dtype=bool)
                keep_mask[preserved_idx] = False
                masked_idx = np.arange(total_link_num)[keep_mask]

                assert len(masked_idx) + len(preserved_idx) == total_link_num

                preserved_links = links[preserved_idx, :]
                masked_links = links[masked_idx, :]

                seeds_masked[(lang1, lang2)] = torch.LongTensor(masked_links)
                seeds_all[(lang1, lang2)] = torch.LongTensor(links)
                seeds_preserved[(lang1, lang2)] = torch.LongTensor(preserved_links)  # to be used to generate the whole graph
            else:
                seeds_masked[(lang1, lang2)] = None
                seeds_all[(lang1, lang2)] = torch.LongTensor(links)
                seeds_preserved[(lang1, lang2)] = None

        return seeds_masked, seeds_all, seeds_preserved

    def create_KG_objects_and_alignment(self):
        """Build one KnowledgeGraph per language with contiguous global
        entity/relation index bases, attach subgraph lists, and wire in the
        alignment links.

        :return: (kg_objects_dict, seeds_masked, seeds_all)
        """
        # Index bookkeeping only: each KG gets a disjoint global id range.
        entity_base = 0
        relation_base = 0
        kg_objects_dict = {}

        for lang in self.kg_names:
            kg_train_data, kg_val_data, kg_test_data, entity_num, relation_num = self.load_kg_data(lang)

            # Every KG except the target acts as a supporter KG.
            is_supporter_kg = (lang != self.target_kg)

            kg_each = KnowledgeGraph(lang, kg_train_data, kg_val_data, kg_test_data,
                                     entity_num, relation_num, is_supporter_kg,
                                     entity_base, relation_base, self.args.device)
            kg_objects_dict[lang] = kg_each

            entity_base += entity_num
            relation_base += relation_num

        self.num_entities = entity_base

        # TODO: create subgraph list, using worker if possible
        for lang in self.kg_names:
            is_target_KG = (lang == self.target_kg)
            kg_lang = kg_objects_dict[lang]
            subgraph_list_self = get_subgraph_list(self.data_path, lang, is_target_KG,
                                                  kg_lang.num_entity, self.args.num_hop,
                                                  self.args.k, kg_lang.entity_id_base,
                                                  kg_lang.relation_id_base)
            kg_lang.subgraph_list_kg = subgraph_list_self
            # Deep copy: the align list is mutated independently below.
            kg_lang.subgraph_list_align = copy.deepcopy(kg_lang.subgraph_list_kg)

        seeds_masked, seeds_all, seeds_preserved = self.load_all_to_all_seed_align_links()

        # Add ALL aligned links to subgraph_list_kg ...
        self.add_subgraph_list_from_align(seeds_all, kg_objects_dict, is_kg_list=True)
        # ... but only the preserved (non-masked) links to subgraph_list_align.
        self.add_subgraph_list_from_align(seeds_preserved, kg_objects_dict, is_kg_list=False)

        return kg_objects_dict, seeds_masked, seeds_all

    def add_subgraph_list_from_align(self, seeds, kg_objects_dict, is_kg_list=False):
        """Inject alignment links between each KG pair into the subgraph lists.

        :param seeds: {(lang1, lang2): LongTensor[N, 2] or None} alignment links.
        :param kg_objects_dict: {lang: KnowledgeGraph}.
        :param is_kg_list: whether to extend subgraph_list_kg (True) or
            subgraph_list_align (False); dispatched inside the helper.
        """
        for (kg0_name, kg1_name), align_links in seeds.items():
            kg0 = kg_objects_dict[kg0_name]
            kg1 = kg_objects_dict[kg1_name]
            subgrarph_list_from_alignment(align_links, kg0, kg1, is_kg_list)

    def load_kg_data(self, language):
        """Load triples and vocabulary sizes for a single language's KG.

        :param language: language code, e.g. 'en'.
        :return: (train, val, test) LongTensors of shape [n_triple, 3]
            (head, relation, tail; local indices — TODO: change indexing to
            global one), plus entity_num and relation_num.
        """
        train_df = pd.read_csv(join(self.data_kg, language + '-train.tsv'), sep='\t',
                               header=None, names=['v1', 'relation', 'v2'])
        val_df = pd.read_csv(join(self.data_kg, language + '-val.tsv'), sep='\t',
                             header=None, names=['v1', 'relation', 'v2'])
        test_df = pd.read_csv(join(self.data_kg, language + '-test.tsv'), sep='\t',
                              header=None, names=['v1', 'relation', 'v2'])

        # Entity count = number of lines in the vocabulary file.
        with open(self.data_entity + language + '.tsv') as f:
            entity_num = len(f.readlines())  # TODO: check whether need to +1/-1

        # Relation vocabulary is shared across languages; +1 presumably
        # reserves an extra id (e.g. padding/self-loop) — TODO confirm.
        with open(join(self.data_path, 'relations.txt')) as f:
            relation_num = len([line.rstrip() for line in f]) + 1

        # np.int was removed in NumPy 1.24; builtin int is the exact equivalent.
        triples_train = train_df.values.astype(int)
        triples_val = val_df.values.astype(int)
        triples_test = test_df.values.astype(int)

        return (torch.LongTensor(triples_train), torch.LongTensor(triples_val),
                torch.LongTensor(triples_test), entity_num, relation_num)

0 commit comments

Comments
 (0)