Skip to content

Commit dc2a1c5

Browse files
committed
updated
updated
1 parent cae558a commit dc2a1c5

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ def save_model(model, output_dir, filename, args):
2323
'args': args,
2424
}, ckpt_path)
2525

26+
def get_negative_samples_alignment(batch_size_each, num_entity,num_negative=None):
27+
'''
28+
Generate one negative sample
29+
:param batch_size_each:
30+
:param num_entity:
31+
:return:
32+
'''
33+
if num_negative == None:
34+
rand_negs = torch.randint(high=num_entity, size=(batch_size_each,)) # [b,n]
35+
else:
36+
rand_negs = torch.randint(high=num_entity, size=(batch_size_each,num_negative)) # [b,n]
37+
38+
return rand_negs
39+
2640

2741
def load_model(ckpt_path, model, device):
2842
if not os.path.exists(ckpt_path):

0 commit comments

Comments
 (0)