|
| 1 | +from os.path import join |
| 2 | +import pandas as pd |
| 3 | +import numpy as np |
| 4 | +import os |
| 5 | +from src.knowledgegraph_pytorch import KnowledgeGraph |
| 6 | +import numpy as np |
| 7 | +from torch.utils.data import Dataset |
| 8 | +import torch |
| 9 | +from src.utils import get_language_list, get_subgraph_list, subgrarph_list_from_alignment |
| 10 | +import copy |
| 11 | + |
| 12 | + |
| 13 | + |
class ParseData(object):
    """Parse a multilingual knowledge-graph dataset from disk.

    For each language this builds a ``KnowledgeGraph`` object (train/val/test
    triples with globally-offset entity/relation index bases), attaches its
    subgraph lists, and loads the pre-computed BERT entity embeddings together
    with the (optionally masked) seed alignment links between language pairs.

    Expected layout under ``args.data_path + args.dataset``:
        entity/           <lang>.tsv entity vocabulary files
        kg/               <lang>-{train,val,test}.tsv triple files
        seed_alignlinks/  <lang1>-<lang2>.tsv seed alignment pairs
        entity_embeddings.npy, relations.txt
    """

    def __init__(self, args):
        self.data_path = args.data_path + args.dataset
        self.data_entity = self.data_path + "/entity/"
        self.data_kg = self.data_path + "/kg/"
        self.data_align = self.data_path + "/seed_alignlinks/"
        self.args = args

        self.target_kg = args.target_language
        self.kg_names = get_language_list(self.data_path)  # all kg names, sorted
        self.num_kgs = len(self.kg_names)

    def load_data(self):
        '''
        Load everything needed for training.

        NOTE: ordering everywhere follows sorted(os.listdir).

        :return: tuple of
            1. kg_object_dict: {language: KnowledgeGraph}
            2. seeds_masked: held-out alignment links (for alignment loss)
            3. seeds_all: all alignment links
            4. entity_bert_emb: BERT entity-embedding matrix, normalized to [-1, 1]
        '''
        entity_bert_emb = np.load(self.data_path + "/entity_embeddings.npy")
        # Normalize features to be within [-1, 1].
        entity_bert_emb = self.normalize_fature(entity_bert_emb)
        kg_object_dict, seeds_masked, seeds_all = self.create_KG_objects_and_alignment()

        # Global relation count: each language contributes its own copy of the
        # (shared) relation vocabulary -- presumably identical per KG; verify.
        self.num_relations = kg_object_dict[self.target_kg].num_relation * self.num_kgs

        return kg_object_dict, seeds_masked, seeds_all, entity_bert_emb

    def normalize_fature(self, input_embedding):
        """Min-max normalize an embedding matrix to the range [-1, 1].

        (Name keeps the original spelling for backward compatibility with
        existing callers.)

        :param input_embedding: np.ndarray of arbitrary shape
        :return: np.ndarray of the same shape, scaled to [-1, 1]
        """
        input_max = input_embedding.max()
        input_min = input_embedding.min()
        value_range = input_max - input_min
        if value_range == 0:
            # Degenerate case (constant matrix): avoid division by zero and
            # map everything to 0.
            return np.zeros_like(input_embedding)
        return (input_embedding - input_min) * 2 / value_range - 1

    def load_all_to_all_seed_align_links(self):
        """Load seed alignment links for every language pair, optionally
        splitting them into a preserved and a masked portion.

        Files in ``seed_alignlinks/`` are named '<lang1>-<lang2>.tsv'
        (e.g. 'el-en.tsv') and hold two integer columns of entity ids.

        :return: (seeds_masked, seeds_all, seeds_preserved), each a dict
            {(lang1, lang2): LongTensor [N, 2]}; masked/preserved entries are
            None when args.preserved_ratio == 1.0.
        """
        seeds_preserved = {}  # {(lang1, lang2): 2-col tensor} -- used to generate the whole graph
        seeds_masked = {}
        seeds_all = {}
        for f in os.listdir(self.data_align):  # e.g. 'el-en.tsv'
            lang1 = f[0:2]
            lang2 = f[3:5]
            links = pd.read_csv(join(self.data_align, f), sep='\t', header=None).values.astype(int)  # [N, 2]

            total_link_num = links.shape[0]
            if self.args.preserved_ratio != 1.0:
                preserved_idx = np.sort(np.random.choice(
                    np.arange(total_link_num),
                    int(total_link_num * self.args.preserved_ratio),
                    replace=False))
                # Set membership keeps the split O(N) instead of O(N^2).
                preserved_set = set(preserved_idx.tolist())
                masked_idx = [i in range(total_link_num) and i for i in range(total_link_num) if i not in preserved_set]

                assert len(masked_idx) + len(preserved_idx) == total_link_num

                seeds_masked[(lang1, lang2)] = torch.LongTensor(links[masked_idx, :])
                seeds_all[(lang1, lang2)] = torch.LongTensor(links)
                seeds_preserved[(lang1, lang2)] = torch.LongTensor(links[preserved_idx, :])
            else:
                seeds_masked[(lang1, lang2)] = None
                seeds_all[(lang1, lang2)] = torch.LongTensor(links)
                seeds_preserved[(lang1, lang2)] = None

        return seeds_masked, seeds_all, seeds_preserved

    def create_KG_objects_and_alignment(self):
        '''
        Build one KnowledgeGraph per language, attach subgraph lists, and
        inject alignment links into them.

        Indices inside each per-language file are LOCAL; entity_base /
        relation_base turn them into a global indexing scheme.

        :return: (kg_objects_dict, seeds_masked, seeds_all)
        '''
        entity_base = 0
        relation_base = 0
        kg_objects_dict = {}

        for lang in self.kg_names:
            kg_train_data, kg_val_data, kg_test_data, entity_num, relation_num = self.load_kg_data(lang)

            # Every non-target language acts as a supporter KG.
            is_supporter_kg = (lang != self.target_kg)

            kg_each = KnowledgeGraph(lang, kg_train_data, kg_val_data, kg_test_data,
                                     entity_num, relation_num, is_supporter_kg,
                                     entity_base, relation_base, self.args.device)
            kg_objects_dict[lang] = kg_each

            entity_base += entity_num
            relation_base += relation_num

        self.num_entities = entity_base

        # TODO: create subgraph list, using worker if possible
        for lang in self.kg_names:
            is_target_KG = (lang == self.target_kg)
            kg_lang = kg_objects_dict[lang]
            subgraph_list_self = get_subgraph_list(self.data_path, lang, is_target_KG,
                                                   kg_lang.num_entity, self.args.num_hop,
                                                   self.args.k, kg_lang.entity_id_base,
                                                   kg_lang.relation_id_base)
            kg_lang.subgraph_list_kg = subgraph_list_self
            # Deep copy so the two lists can diverge once alignment links are
            # added below.
            kg_lang.subgraph_list_align = copy.deepcopy(kg_lang.subgraph_list_kg)

        seeds_masked, seeds_all, seeds_preserved = self.load_all_to_all_seed_align_links()

        # Add all aligned links to subgraph_list_kg ...
        self.add_subgraph_list_from_align(seeds_all, kg_objects_dict, is_kg_list=True)
        # ... but only the preserved ones to subgraph_list_align.
        self.add_subgraph_list_from_align(seeds_preserved, kg_objects_dict, is_kg_list=False)

        return kg_objects_dict, seeds_masked, seeds_all

    def add_subgraph_list_from_align(self, seeds, kg_objects_dict, is_kg_list=False):
        """Inject alignment links of every language pair into the two
        corresponding KGs' subgraph lists (mutates the KG objects in place).

        :param seeds: {(lang1, lang2): LongTensor [N, 2] or None}
        :param kg_objects_dict: {language: KnowledgeGraph}
        :param is_kg_list: True -> update subgraph_list_kg, else subgraph_list_align
        """
        for (kg0_name, kg1_name), align_links in seeds.items():
            kg0 = kg_objects_dict[kg0_name]
            kg1 = kg_objects_dict[kg1_name]
            subgrarph_list_from_alignment(align_links, kg0, kg1, is_kg_list)

    def load_kg_data(self, language):
        """
        Load triples and stats for a single language's KG (LOCAL indexing).

        :param language: language code, e.g. 'en'
        :return: (train, val, test) LongTensors of shape [n_triple, 3],
                 plus entity_num and relation_num
        TODO: change indexing to global one.
        """
        col_names = ['v1', 'relation', 'v2']
        train_df = pd.read_csv(join(self.data_kg, language + '-train.tsv'), sep='\t', header=None, names=col_names)
        val_df = pd.read_csv(join(self.data_kg, language + '-val.tsv'), sep='\t', header=None, names=col_names)
        test_df = pd.read_csv(join(self.data_kg, language + '-test.tsv'), sep='\t', header=None, names=col_names)

        # One entity per line of the vocabulary file.
        with open(self.data_entity + language + '.tsv') as f:
            entity_num = len(f.readlines())  # TODO: check whether need to +1/-1

        # Relations are shared across languages; the +1 presumably reserves an
        # extra id (e.g. for a special/self-loop relation) -- TODO confirm.
        with open(join(self.data_path, 'relations.txt')) as f:
            relation_num = len([line.rstrip() for line in f]) + 1

        # np.int was removed in NumPy 1.24; use an explicit dtype instead.
        triples_train = train_df.values.astype(np.int64)
        triples_val = val_df.values.astype(np.int64)
        triples_test = test_df.values.astype(np.int64)

        return (torch.LongTensor(triples_train), torch.LongTensor(triples_val),
                torch.LongTensor(triples_test), entity_num, relation_num)
| 179 | + |
| 180 | + |
| 181 | + |
0 commit comments