
Commit 0046faa

transform functions fix

1 parent 38dc0ce commit 0046faa

File tree

1 file changed (+91, -1)

utils/tranform_functions.py

@@ -339,6 +339,95 @@ def generate_ngram_sequences(data, seq_len_right, seq_len_left):
             sequence_dict[key] = left_seq + right_seq
             i += 1
     return sequence_dict
+
+def validate_sequences(sequence_dict, seq_len_right, seq_len_left):
+    micro_sequences = []
+    macro_sequences = {}
+
+    for key in sequence_dict.keys():
+        score = sequence_dict[key]
+
+        if score < 1 and len(key.split()) <= seq_len_right:
+            micro_sequences.append(key)
+        else:
+            macro_sequences[key] = score
+
+    non_frag_sequences = []
+    macro_sequences_copy = macro_sequences.copy()
+
+    for sent in tqdm(micro_sequences, total=len(micro_sequences)):
+        for key in macro_sequences.keys():
+            if sent in key:
+                non_frag_sequences.append(key)
+                del macro_sequences_copy[key]
+
+        macro_sequences = macro_sequences_copy.copy()
+
+    for sent in non_frag_sequences:
+        macro_sequences[sent] = 0
+
+    for sent in micro_sequences:
+        macro_sequences[sent] = 0
+
+    return macro_sequences
+
+def create_fragment_detection_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
+
+    """
+    This function transforms data for the fragment detection task (detecting whether a sentence is an incomplete/fragment sentence or not).
+    It takes data in single-sentence classification format and creates fragment samples from the sentences.
+    In the transformed file, labels 1 and 0 represent fragment and non-fragment sentences respectively.
+    The following transformed file is written at wrtDir
+
+    - Fragment transformed tsv file containing fragment/non-fragment sentences and labels
+
+
+    For using this transform function, set ``transform_func`` : **create_fragment_detection_tsv** in the transform file.
+
+    Args:
+        dataDir (:obj:`str`) : Path to the directory where the raw data files to be read are present.
+        readFile (:obj:`str`) : The file which is currently being read and transformed by the function.
+        wrtDir (:obj:`str`) : Path to the directory where the transformed tsv files are to be saved.
+        transParamDict (:obj:`dict`, defaults to :obj:`None`): Dictionary requiring the following parameters as key-value
+
+            - ``data_frac`` (defaults to 0.2) : Fraction of data to consider for making fragments.
+            - ``seq_len_right`` (defaults to 3) : Right window length for making n-grams.
+            - ``seq_len_left`` (defaults to 2) : Left window length for making n-grams.
+            - ``sep`` (defaults to "\t") : Column separator for the input file.
+            - ``query_col`` (defaults to 2) : Column number containing sentences. Counting starts from 0.
+
+    """
+
+    transParamDict.setdefault("data_frac", 0.2)
+    transParamDict.setdefault("seq_len_right", 3)
+    transParamDict.setdefault("seq_len_left", 2)
+    transParamDict.setdefault("sep", "\t")
+    transParamDict.setdefault("query_col", 2)
+
+    allDataDf = pd.read_csv(os.path.join(dataDir, readFile), sep=transParamDict["sep"], header=None)
+    sampledDataDf = allDataDf.sample(frac=float(transParamDict['data_frac']), random_state=42)
+
+    # query_col (default 2) is considered to hold the queries; 0th column is uid, 1st is label
+    # make n-grams with the left and right windows
+    seqDict = generate_ngram_sequences(data=list(sampledDataDf.iloc[:, int(transParamDict["query_col"])]),
+                                       seq_len_right=transParamDict['seq_len_right'],
+                                       seq_len_left=transParamDict['seq_len_left'])
+
+    fragDict = validate_sequences(seqDict, seq_len_right=transParamDict['seq_len_right'],
+                                  seq_len_left=transParamDict['seq_len_left'])
+
+    finalDf = pd.DataFrame({'uid': [i for i in range(len(fragDict) + len(allDataDf))],
+                            'label': [1]*len(fragDict) + [0]*len(allDataDf),
+                            'query': list(fragDict.keys()) + list(allDataDf.iloc[:, int(transParamDict["query_col"])])})
+
+    print('number of fragment samples : ', len(fragDict))
+    print('number of non-fragment samples : ', len(allDataDf))
+
+    # saving
+    print('writing fragment file for {} at {}'.format(readFile, wrtDir))
+    finalDf.to_csv(os.path.join(wrtDir, 'fragment_{}.tsv'.format(readFile.split('.')[0])), sep='\t',
+                   index=False, header=False)
+
 def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
     """
     This function transforms the MSMARCO triples data available at `triples <https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz>`_
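To make the filtering in validate_sequences concrete, here is a small hand-traced sketch. The sentences and scores are invented for illustration; only the function itself comes from the commit.

# Hand-traced sketch of validate_sequences -- invented sentences/scores.
from utils.tranform_functions import validate_sequences

seqDict = {
    "book a cab": 0,            # score < 1 and 3 words <= seq_len_right -> micro sequence
    "i want to book a cab": 2,  # kept as macro sequence
    "please help me": 3,        # kept as macro sequence
}
result = validate_sequences(seqDict, seq_len_right=3, seq_len_left=2)
# "i want to book a cab" contains the micro sequence "book a cab", so it is
# recorded as a non-fragment and re-added with score 0; the micro sequence
# itself is also added with score 0.
# result == {"please help me": 3, "i want to book a cab": 0, "book a cab": 0}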
@@ -412,7 +501,7 @@ def msmarco_answerability_detection_to_tsv(dataDir, readFile, wrtDir, transParam
 
     devDf.to_csv(os.path.join(wrtDir, 'msmarco_answerability_test.tsv'), sep='\t', index=False, header=False)
     print('Test file written at: ', os.path.join(wrtDir, 'msmarco_answerability_test.tsv'))
-
+
 def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
 
     """
@@ -458,6 +547,7 @@ def msmarco_query_type_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrain
     labelMapPath = os.path.join(wrtDir, 'querytype_{}_label_map.joblib'.format(readFile.lower().replace('.json', '')))
     joblib.dump(labelMap, labelMapPath)
     print('Created label map file at', labelMapPath)
+
 
 def imdb_sentiment_data_to_tsv(dataDir, readFile, wrtDir, transParamDict, isTrainFile=False):
 
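For reference, a minimal sketch of calling the new transform directly rather than through a transform file. The paths, the file name queries.tsv, and the three-column layout (uid, label, query) are assumptions for illustration; the parameter names and defaults are taken from create_fragment_detection_tsv above.

# Usage sketch -- hypothetical paths and file names, not part of the commit.
# Assumes a headerless TSV with columns: 0 uid, 1 label, 2 query.
from utils.tranform_functions import create_fragment_detection_tsv

transParams = {
    "data_frac": 0.2,      # fraction of sentences sampled for fragment creation
    "seq_len_right": 3,    # right n-gram window length
    "seq_len_left": 2,     # left n-gram window length
    "sep": "\t",           # input column separator
    "query_col": 2,        # column holding the sentences
}

create_fragment_detection_tsv(dataDir="data", readFile="queries.tsv",
                              wrtDir="data/fragment", transParamDict=transParams)
# Writes data/fragment/fragment_queries.tsv: label 1 rows are the generated
# fragments, label 0 rows are the original (complete) sentences.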
