-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathhitter-bert.py
372 lines (307 loc) · 15.2 KB
/
hitter-bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import json
import math
import argparse
from pathlib import Path
from transformers import BertTokenizer, BertForMaskedLM, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything
from kge.model import KgeModel
from kge.util.io import load_checkpoint
from kge.util import sc
# Name of the pretrained checkpoint; used below for the tokenizer and by
# CrossTrmFinetuner when it loads BertForMaskedLM.
MODEL = 'bert-base-uncased'
# Module-level tokenizer, created once at import time. fbqa_collate and
# CrossTrmFinetuner.validation_step both read this global.
tokenizer = BertTokenizer.from_pretrained(MODEL)
class FBQADataset(Dataset):
    """Map-style dataset backed by a JSON file that is parsed once, eagerly.

    The parsed top-level JSON value is kept in ``self.examples`` and indexed
    by position (presumably a list of example dicts; verify against the data
    files).
    """

    def __init__(self, file_dir):
        """Read and parse the JSON file at ``file_dir``.

        Args:
            file_dir: Path of the JSON file; anything ``pathlib.Path`` accepts.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # A context manager closes the handle deterministically; the previous
        # form left the open file to be reclaimed by the garbage collector.
        with Path(file_dir).open("rb") as handle:
            self.examples = json.load(handle)

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return the example at ``idx``; tensor indices are converted first."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.examples[idx]
def fbqa_collate(samples):
    """Turn a list of raw example dicts into the batch tuple used by the model.

    For every sample the question text is extended with one ``[MASK]`` per
    element of ``AnswerEntity`` followed by a period, then tokenized with the
    module-level ``tokenizer``. The topic entity name is tokenized separately
    and its first occurrence inside the question tokens is located.

    Args:
        samples: Sequence of dicts providing the keys ``RawQuestion``,
            ``AnswerEntity``, ``AnswerEntityID``, ``TopicEntityID``,
            ``TopicEntityName`` and ``RelationID``.

    Returns:
        A 9-tuple ``(inputs, masked_labels, answers, answers_lengths,
        answer_ids, entities, relations, entity_mask, entity_span_index)``
        where ``inputs`` is the tokenizer output's underlying ``data`` dict.
    """
    question_texts = [
        sample["RawQuestion"] + "[MASK]" * len(sample["AnswerEntity"]) + "."
        for sample in samples
    ]
    answer_tokens = [sample["AnswerEntity"] for sample in samples]
    answer_id_list = [sample["AnswerEntityID"] for sample in samples]
    topic_ids = [sample["TopicEntityID"] for sample in samples]
    topic_names = [sample["TopicEntityName"] for sample in samples]
    relation_ids = [sample["RelationID"] for sample in samples]

    questions = tokenizer(question_texts, return_tensors='pt', padding=True)
    entity_names = tokenizer(topic_names, add_special_tokens=False)

    padded_answers, lengths = sc.pad_seq_of_seq(answer_tokens)
    answers = torch.LongTensor(padded_answers)
    answers_lengths = torch.LongTensor(lengths)
    answer_ids = torch.LongTensor(answer_id_list)

    input_ids = questions['input_ids']
    # -100 everywhere except at [MASK] positions, which receive the answer
    # token ids in order.
    # NOTE(review): `answers != 0` presumably strips the padding added by
    # sc.pad_seq_of_seq -- confirm that its pad value is 0.
    masked_labels = torch.full_like(input_ids, -100)
    masked_labels[input_ids == tokenizer.mask_token_id] = answers[answers != 0]

    entity_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    entity_span_index = input_ids.new_zeros((len(input_ids), 2))
    for row, name_tokens in enumerate(entity_names['input_ids']):
        name_tokens = list(name_tokens)
        span_len = len(name_tokens)
        row_tokens = input_ids[row].tolist()
        # Only the first match is recorded; rows without a match keep an
        # all-False mask and a (0, 0) span.
        for start in range(len(row_tokens) - span_len):
            stop = start + span_len
            if row_tokens[start:stop] == name_tokens:
                entity_mask[row, start:stop] = True
                entity_span_index[row, 0] = start
                entity_span_index[row, 1] = stop - 1
                break

    entities = torch.LongTensor(topic_ids)
    relations = torch.LongTensor(relation_ids)
    return (questions.data, masked_labels, answers, answers_lengths, answer_ids,
            entities, relations, entity_mask, entity_span_index)
class SelfOutput(nn.Module):
    """Linear projection + dropout, followed by a residual add and LayerNorm."""

    def __init__(self, config):
        """Build the sub-layers from ``config``.

        Args:
            config: Object exposing ``hidden_size``, ``layer_norm_eps`` and
                ``hidden_dropout_prob``.
        """
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class CrossAttention(nn.Module):
    """Cross-attention block: attention over a context plus an output layer.

    Combines ``CrossAttentionInternal`` (stored as ``self.self``) with
    ``SelfOutput`` (stored as ``self.output``) and re-initialises all of its
    weights on construction.
    """

    def __init__(self, config, ctx_hidden_size):
        """Create the sub-modules and initialise their weights.

        Args:
            config: Configuration forwarded to both sub-modules; it must also
                provide ``initializer_range``.
            ctx_hidden_size: Feature size of the context that is attended to.
        """
        super().__init__()
        self.self = CrossAttentionInternal(config, ctx_hidden_size)
        self.output = SelfOutput(config)
        self.config = config
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        # Linear biases always start at zero.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Returns:
            A tuple whose first element is the attention output after the
            residual/LayerNorm step; any further elements produced by the
            inner attention module (the attention probabilities) follow.
        """
        inner = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
        )
        context, *extras = inner
        # The residual connection uses the original query-side hidden states.
        return (self.output(context, hidden_states), *extras)
class CrossAttentionInternal(nn.Module):
    """Multi-head attention whose keys and values come from an external context."""

    def __init__(self, config, ctx_hidden_size):
        """Create the query/key/value projections.

        Args:
            config: Object exposing ``hidden_size``, ``num_attention_heads``
                and ``attention_probs_dropout_prob``.
            ctx_hidden_size: Feature size of the context fed to the key and
                value projections.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the number of
                heads and ``config`` has no ``embedding_size`` attribute.
        """
        super().__init__()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        if hidden_size % num_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_heads)
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = hidden_size // num_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(ctx_hidden_size, self.all_head_size)
        self.value = nn.Linear(ctx_hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move the head axis to position 1."""
        head_shape = (self.num_attention_heads, self.attention_head_size)
        return x.view(*x.size()[:-1], *head_shape).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        ``attention_mask`` is accepted for signature compatibility but is not
        used: only ``encoder_attention_mask`` is added to the scores.

        Returns:
            ``(context,)`` or, when ``output_attentions`` is true,
            ``(context, attention_probs)`` with the probabilities taken
            before dropout and head masking.
        """
        queries = self.transpose_for_scores(self.query(hidden_states))
        # Keys and values always come from the encoder/context side.
        keys = self.transpose_for_scores(self.key(encoder_hidden_states))
        values = self.transpose_for_scores(self.value(encoder_hidden_states))

        # Scaled dot-product scores between queries and context keys.
        scores = torch.matmul(queries, keys.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        if encoder_attention_mask is not None:
            # The mask is additive.
            scores = scores + encoder_attention_mask

        probs = scores.softmax(dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        weights = self.dropout(probs)
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, values).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        if output_attentions:
            # Dropout is not in-place, so `probs` still holds the plain softmax.
            return (context, probs)
        return (context,)
class CrossTrmFinetuner(pl.LightningModule):
    """Lightning module that fine-tunes a BERT masked LM on QA examples.

    When ``hparams['use_hitter']`` is true, a pretrained KGE model is loaded
    from a checkpoint and a stack of ``CrossAttention`` modules is created;
    both are handed to BERT through its ``kg_attentions`` argument.
    NOTE(review): ``kg_attentions`` is not a stock ``BertForMaskedLM``
    argument -- this presumably relies on a patched transformers build.
    """

    def __init__(self, hparams):
        """Store ``hparams`` (a dict) and build the sub-models.

        Keys read by this class: ``use_hitter``, ``batch_size``,
        ``train_dataset``, ``val_dataset``, ``test_dataset`` and ``max_steps``
        (the last one must be set before ``configure_optimizers`` runs).
        """
        super().__init__()
        self._hparams = hparams
        self.kg_dim = 320
        self.bert = BertForMaskedLM.from_pretrained(MODEL)
        if self._hparams['use_hitter']:
            self.kg_layer_num = 10
            self.cross_attentions = nn.ModuleList(
                [CrossAttention(self.bert.config, self.kg_dim) for _ in range(self.kg_layer_num)]
            )
            checkpoint = load_checkpoint('local/best/20200812-174221-trmeh-fb15k237-best/checkpoint_best.pt')
            self.hitter = KgeModel.create_from(checkpoint)

    def forward(self, batch):
        """Run BERT on a collated batch and return its output object."""
        sent_input, masked_labels, _, _, _, s, p, _, _ = batch
        kg_attentions = []
        if self._hparams['use_hitter']:
            # kg_masks: [bs, 1, 1, length]
            # kg_embeds: nlayer*[bs, length, dim]
            kg_embeds, kg_masks = self.hitter('get_hitter_repr', s, p)
            # The first two slots are None; each following pair of slots
            # shares one entry of kg_embeds, starting at index 1.
            kg_attentions = [None, None]
            for layer_idx in range(self.kg_layer_num):
                kg_attentions.append(
                    (self.cross_attentions[layer_idx], kg_embeds[layer_idx // 2 + 1], kg_masks)
                )
        return self.bert(
            kg_attentions=kg_attentions,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
            labels=masked_labels,
            **sent_input,
        )

    def training_step(self, batch, batch_idx):
        """Compute and log the masked-LM loss for one training batch."""
        loss = self(batch).loss
        self.log('train_loss', loss, on_epoch=True, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Score one validation batch.

        A sample counts as a hit only when every masked position's top-1
        prediction equals the corresponding answer token.
        """
        batch_inputs, masked_labels, batch_labels, label_lens, _, _, _, _, _ = batch
        output = self(batch)
        input_tokens = batch_inputs["input_ids"].clone()

        # Keep only the logits at labelled (masked) positions.
        masked_logits = output.logits[masked_labels != -100]
        _, predictions = masked_logits.softmax(dim=-1).topk(1)

        hits = []
        offset = 0  # position of the current sample's first mask in `predictions`
        for sample_i, label_length in enumerate(label_lens.tolist()):
            all_match = all(
                (predictions[offset + i] == batch_labels[sample_i][i]).sum() == 1
                for i in range(label_length)
            )
            hits.append(1 if all_match else 0)
            offset += label_length
        hits = torch.tensor(hits)

        # Fill the masks with the predictions so the decoded text is readable.
        input_tokens[input_tokens == tokenizer.mask_token_id] = predictions.flatten()
        pred_strings = [
            str(hits[i].item()) + ' ' + tokenizer.decode(row, skip_special_tokens=True)
            for i, row in enumerate(input_tokens)
        ]
        return {
            'val_loss': output.loss,
            'val_acc': hits.float(),
            'pred_strings': pred_strings,
        }

    def validation_epoch_end(self, outputs):
        """Aggregate validation results, log them and dump predictions as text."""
        avg_loss = torch.stack([step['val_loss'] for step in outputs]).mean()
        avg_val_acc = torch.cat([step['val_acc'] for step in outputs]).mean().to(avg_loss.device)
        if self.global_rank == 0:
            all_strings = [line for step in outputs for line in step['pred_strings']]
            self.logger.experiment.add_text('pred', '\n\n'.join(all_strings), self.global_step)
        self.log('avg_loss', avg_loss, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('avg_val_acc', avg_val_acc, on_epoch=True, prog_bar=True, sync_dist=True)
        return {'val_loss': avg_loss}

    def _make_loader(self, dataset_key, batch_size, shuffle):
        """Build a DataLoader over the dataset file stored under ``dataset_key``."""
        return DataLoader(
            FBQADataset(self._hparams[dataset_key]),
            batch_size,
            shuffle=shuffle,
            collate_fn=fbqa_collate,
            num_workers=0,
        )

    def train_dataloader(self):
        """Shuffled loader over the training file with the configured batch size."""
        return self._make_loader('train_dataset', self._hparams['batch_size'], True)

    def val_dataloader(self):
        """Unshuffled loader over the validation file, one example per batch."""
        return self._make_loader('val_dataset', 1, False)

    def test_dataloader(self):
        """Unshuffled loader over the test file, one example per batch."""
        return self._make_loader('test_dataset', 1, False)

    def configure_optimizers(self):
        """Create AdamW with per-group settings and a linear warmup schedule.

        Parameters whose name contains ``cross_attentions`` get their own
        group (lr 5e-5); the rest use lr 5e-6, with weight decay disabled for
        biases and LayerNorm weights. Warmup lasts a tenth of ``max_steps``.
        """
        no_decay = ('bias', 'LayerNorm.weight')
        decayed, undecayed = [], []
        for name, param in self.named_parameters():
            if 'cross_attentions' in name:
                continue  # handled by the dedicated group below
            if any(marker in name for marker in no_decay):
                undecayed.append(param)
            else:
                decayed.append(param)

        param_groups = [
            {'params': decayed, 'weight_decay': 0.01},
            {'params': undecayed, 'weight_decay': 0.0},
        ]
        if self._hparams['use_hitter']:
            param_groups.append(
                {'params': self.cross_attentions.parameters(), 'lr': 5e-5, 'weight_decay': 0.01}
            )

        bert_optimizer = AdamW(param_groups, lr=5e-6)
        total_steps = self._hparams['max_steps']
        bert_scheduler = {
            'scheduler': get_linear_schedule_with_warmup(bert_optimizer, total_steps // 10, total_steps),
            'interval': 'step',
            'monitor': None,
        }
        return [bert_optimizer], [bert_scheduler]
if __name__ == '__main__':
    # Command-line entry point: parse options, build the model and train it.
    parser = argparse.ArgumentParser()
    parser.add_argument("exp_name", default='default', nargs='?', help="Name of the experiment")
    parser.add_argument('--dataset', choices=['fbqa', 'webqsp'], default='fbqa', help="fbqa or webqsp")
    parser.add_argument('--filtered', default=False, action='store_true', help="Filtered or not")
    parser.add_argument('--hitter', default=False, action='store_true', help="Use pretrained HittER or not")
    parser.add_argument('--seed', default=333, type=int, help='Seed number')
    args = parser.parse_args()

    seed_everything(args.seed)

    QA_DATASET = args.dataset
    SUBSET = 'fb15k237' if not args.filtered else 'fb15k237-filtered'

    # NOTE(review): validation and test both point at test.json.
    hparams = dict(
        use_hitter=args.hitter,
        batch_size=16,
        max_epochs=20,
        train_dataset=f'data/{QA_DATASET}/{SUBSET}/train.json',
        val_dataset=f'data/{QA_DATASET}/{SUBSET}/test.json',
        test_dataset=f'data/{QA_DATASET}/{SUBSET}/test.json',
    )

    model = CrossTrmFinetuner(hparams)
    # The scheduler in configure_optimizers needs the total step count, which
    # is only known once the training set has been loaded.
    num_train_examples = len(model.train_dataloader().dataset)
    steps_per_epoch = num_train_examples // hparams['batch_size'] + 1
    model.hparams['max_steps'] = steps_per_epoch * hparams['max_epochs']

    base_path = '/tmp/hitbert-paper'
    logger = TensorBoardLogger(base_path, args.exp_name)
    # Keep only the single checkpoint with the highest validation accuracy.
    checkpoint_callback = ModelCheckpoint(
        monitor='avg_val_acc',
        dirpath=base_path + '/' + args.exp_name,
        filename='{epoch:02d}-{avg_val_acc:.3f}',
        save_top_k=1,
        mode='max',
    )
    trainer = pl.Trainer(
        gpus=1,
        accelerator="ddp",
        max_epochs=hparams['max_epochs'],
        max_steps=model.hparams['max_steps'],
        checkpoint_callback=True,
        gradient_clip_val=1.0,
        logger=logger,
        callbacks=[LearningRateMonitor(), checkpoint_callback],
    )
    trainer.fit(model)