import logging as log
from typing import Dict

import torch
import torch.nn as nn

from allennlp.modules import scalar_mix

# huggingface implementation of BERT
import pytorch_pretrained_bert

from jiant.preprocess import parse_task_list_arg


def _get_seg_ids(ids, sep_id):
""" Dynamically build the segment IDs for a concatenated pair of sentences
Searches for index SEP_ID in the tensor
args:
ids (torch.LongTensor): batch of token IDs
returns:
seg_ids (torch.LongTensor): batch of segment IDs
example:
> sents = ["[CLS]", "I", "am", "a", "cat", ".", "[SEP]", "You", "like", "cats", "?", "[SEP]"]
> token_tensor = torch.Tensor([[vocab[w] for w in sent]]) # a tensor of token indices
> seg_ids = _get_seg_ids(token_tensor, sep_id=102) # BERT [SEP] ID
> assert seg_ids == torch.LongTensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
"""
    # (row, col) positions of every [SEP]; keep only the column indices.
    sep_idxs = (ids == sep_id).nonzero()[:, 1]
    seg_ids = torch.ones_like(ids)
    # Each pair-task row contains two [SEP]s, so sep_idxs[::2] picks the first
    # [SEP] of each row; everything up to and including it belongs to segment 0.
    for row, idx in zip(seg_ids, sep_idxs[::2]):
        row[: idx + 1].fill_(0)
    return seg_ids


class BertEmbedderModule(nn.Module):
""" Wrapper for BERT module to fit into jiant APIs. """
def __init__(self, args, cache_dir=None):
super(BertEmbedderModule, self).__init__()
self.model = pytorch_pretrained_bert.BertModel.from_pretrained(
args.input_module, cache_dir=cache_dir
)
self.embeddings_mode = args.bert_embeddings_mode
self.num_layers = self.model.config.num_hidden_layers
if args.bert_max_layer >= 0:
self.max_layer = args.bert_max_layer
else:
self.max_layer = self.num_layers
assert self.max_layer <= self.num_layers
tokenizer = pytorch_pretrained_bert.BertTokenizer.from_pretrained(
args.input_module, cache_dir=cache_dir
)
self._sep_id = tokenizer.vocab["[SEP]"]
self._pad_id = tokenizer.vocab["[PAD]"]
# Set trainability of this module.
for param in self.model.parameters():
param.requires_grad = bool(args.transfer_paradigm == "finetune")
# Configure scalar mixing, ELMo-style.
if self.embeddings_mode == "mix":
if args.transfer_paradigm == "frozen":
log.warning(
"NOTE: bert_embeddings_mode='mix', so scalar "
"mixing weights will be fine-tuned even if BERT "
"model is frozen."
)
# TODO: if doing multiple target tasks, allow for multiple sets of
# scalars. See the ELMo implementation here:
# https://github.com/allenai/allennlp/blob/master/allennlp/modules/elmo.py#L115
assert len(parse_task_list_arg(args.target_tasks)) <= 1, (
"bert_embeddings_mode='mix' only supports a single set of "
"scalars (but if you need this feature, see the TODO in "
"the code!)"
)
# Always have one more mixing weight, for lexical layer.
self.scalar_mix = scalar_mix.ScalarMix(self.max_layer + 1, do_layer_norm=False)
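            # For reference: AllenNLP's ScalarMix learns softmax-normalized
            # weights s_k plus a global gamma and returns
            # gamma * sum_k s_k * layer_k over its max_layer + 1 inputs
            # (the lexical layer plus each transformer layer).
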
def forward(
self, sent: Dict[str, torch.LongTensor], unused_task_name: str = "", is_pair_task=False
) -> torch.FloatTensor:
""" Run BERT to get hidden states.
This forward method does preprocessing on the go,
changing token IDs from preprocessed bert to
what AllenNLP indexes.
Args:
sent: batch dictionary
is_pair_task (bool): true if input is a batch from a pair task
Returns:
h: [batch_size, seq_len, d_emb]
"""
assert "bert_wpm_pretokenized" in sent
# <int32> [batch_size, var_seq_len]
ids = sent["bert_wpm_pretokenized"]
# BERT supports up to 512 tokens; see section 3.2 of https://arxiv.org/pdf/1810.04805.pdf
assert ids.size()[1] <= 512
mask = ids != 0
# "Correct" ids to account for different indexing between BERT and
# AllenNLP.
# The AllenNLP indexer adds a '@@UNKNOWN@@' token to the
# beginning of the vocabulary, *and* treats that as index 1 (index 0 is
# reserved for padding).
ids[ids == 0] = self._pad_id + 2 # Shift the indices that were at 0 to become 2.
# Index 1 should never be used since the BERT WPM uses its own
# unk token, and handles this at the string level before indexing.
assert (ids > 1).all()
ids -= 2 # shift indices to match BERT wordpiece embeddings
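        # (For example, with the bert-base-* vocab the [CLS] wordpiece is ID 101
        # in BERT's own vocabulary; the AllenNLP indexer stores it as 103, and
        # the shift above maps it back to 101.)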
if self.embeddings_mode not in ["none", "top"]:
# This is redundant with the lookup inside BertModel,
# but doing so this way avoids the need to modify the BertModel
# code.
# Extract lexical embeddings; see
# https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L186 # noqa
h_lex = self.model.embeddings.word_embeddings(ids)
h_lex = self.model.embeddings.LayerNorm(h_lex)
            # Following our use of the OpenAI model, don't use dropout for
            # probing. If you would like to use dropout, consider applying
            # it later on in the SentenceEncoder (see models.py).
            # h_lex = self.model.embeddings.dropout(h_lex)
else:
h_lex = None # dummy; should not be accessed.
if self.embeddings_mode != "only":
# encoded_layers is a list of layer activations, each of which is
# <float32> [batch_size, seq_len, output_dim]
token_types = _get_seg_ids(ids, self._sep_id) if is_pair_task else torch.zeros_like(ids)
encoded_layers, _ = self.model(
ids, token_type_ids=token_types, attention_mask=mask, output_all_encoded_layers=True
)
else:
encoded_layers = [] # 'only' mode is embeddings only
all_layers = [h_lex] + encoded_layers
all_layers = all_layers[: self.max_layer + 1]
if self.embeddings_mode in ["none", "top"]:
h = all_layers[-1]
elif self.embeddings_mode == "only":
h = all_layers[0]
elif self.embeddings_mode == "cat":
h = torch.cat([all_layers[-1], all_layers[0]], dim=2)
elif self.embeddings_mode == "mix":
h = self.scalar_mix(all_layers, mask=mask)
else:
            raise NotImplementedError(f"embeddings_mode={self.embeddings_mode} not supported.")
# <float32> [batch_size, var_seq_len, output_dim]
return h

    def get_output_dim(self):
if self.embeddings_mode == "cat":
return 2 * self.model.config.hidden_size
else:
return self.model.config.hidden_size
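

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal, hedged example of how this wrapper might be driven. The `args`
# namespace below is hypothetical: the field names mirror the attributes read
# in __init__ and forward() above, but jiant's real config object may differ,
# and `ids` is assumed to be a batch of AllenNLP-indexed wordpiece IDs.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       input_module="bert-base-uncased",
#       bert_embeddings_mode="top",     # "none", "top", "only", "cat", or "mix"
#       bert_max_layer=-1,              # -1 = use all hidden layers
#       transfer_paradigm="frozen",     # or "finetune"
#       target_tasks="",                # only checked when embeddings_mode="mix"
#   )
#   embedder = BertEmbedderModule(args)
#   batch = {"bert_wpm_pretokenized": ids}        # <int64> [batch_size, seq_len]
#   h = embedder(batch, is_pair_task=False)       # [batch_size, seq_len, d_emb]
#   assert h.size(-1) == embedder.get_output_dim()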