add attention to all bert models and add test
This commit is contained in:
@@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertForPreTraining, self).__init__(config)
|
super(BertForPreTraining, self).__init__(config)
|
||||||
self.bert = BertModel(config)
|
self.output_attentions = output_attentions
|
||||||
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
|
||||||
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||||
output_all_encoded_layers=False)
|
output_all_encoded_layers=False)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, sequence_output, pooled_output = outputs
|
||||||
|
else:
|
||||||
|
sequence_output, pooled_output = outputs
|
||||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||||
|
|
||||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||||
@@ -830,7 +835,8 @@ class BertForPreTraining(BertPreTrainedModel):
|
|||||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||||
total_loss = masked_lm_loss + next_sentence_loss
|
total_loss = masked_lm_loss + next_sentence_loss
|
||||||
return total_loss
|
return total_loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, prediction_scores, seq_relationship_score
|
||||||
return prediction_scores, seq_relationship_score
|
return prediction_scores, seq_relationship_score
|
||||||
|
|
||||||
|
|
||||||
@@ -876,22 +882,28 @@ class BertForMaskedLM(BertPreTrainedModel):
|
|||||||
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
|
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertForMaskedLM, self).__init__(config)
|
super(BertForMaskedLM, self).__init__(config)
|
||||||
self.bert = BertModel(config)
|
self.output_attentions = output_attentions
|
||||||
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
|
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
|
||||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
|
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||||
output_all_encoded_layers=False)
|
output_all_encoded_layers=False)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, sequence_output, _ = outputs
|
||||||
|
else:
|
||||||
|
sequence_output, _ = outputs
|
||||||
prediction_scores = self.cls(sequence_output)
|
prediction_scores = self.cls(sequence_output)
|
||||||
|
|
||||||
if masked_lm_labels is not None:
|
if masked_lm_labels is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||||
return masked_lm_loss
|
return masked_lm_loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, prediction_scores
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
@@ -938,22 +950,28 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
|||||||
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertForNextSentencePrediction, self).__init__(config)
|
super(BertForNextSentencePrediction, self).__init__(config)
|
||||||
self.bert = BertModel(config)
|
self.output_attentions = output_attentions
|
||||||
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
self.cls = BertOnlyNSPHead(config)
|
self.cls = BertOnlyNSPHead(config)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
|
||||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||||
output_all_encoded_layers=False)
|
output_all_encoded_layers=False)
|
||||||
seq_relationship_score = self.cls( pooled_output)
|
if self.output_attentions:
|
||||||
|
all_attentions, _, pooled_output = outputs
|
||||||
|
else:
|
||||||
|
_, pooled_output = outputs
|
||||||
|
seq_relationship_score = self.cls(pooled_output)
|
||||||
|
|
||||||
if next_sentence_label is not None:
|
if next_sentence_label is not None:
|
||||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||||
return next_sentence_loss
|
return next_sentence_loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, seq_relationship_score
|
||||||
return seq_relationship_score
|
return seq_relationship_score
|
||||||
|
|
||||||
|
|
||||||
@@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, num_labels):
|
def __init__(self, config, num_labels, output_attentions=False):
|
||||||
super(BertForSequenceClassification, self).__init__(config)
|
super(BertForSequenceClassification, self).__init__(config)
|
||||||
|
self.output_attentions = output_attentions
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, _, pooled_output = outputs
|
||||||
|
else:
|
||||||
|
_, pooled_output = outputs
|
||||||
pooled_output = self.dropout(pooled_output)
|
pooled_output = self.dropout(pooled_output)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
@@ -1019,7 +1042,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, num_choices):
|
def __init__(self, config, num_choices, output_attentions=False):
|
||||||
super(BertForMultipleChoice, self).__init__(config)
|
super(BertForMultipleChoice, self).__init__(config)
|
||||||
|
self.output_attentions = output_attentions
|
||||||
self.num_choices = num_choices
|
self.num_choices = num_choices
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
@@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, _, pooled_output = outputs
|
||||||
|
else:
|
||||||
|
_, pooled_output = outputs
|
||||||
pooled_output = self.dropout(pooled_output)
|
pooled_output = self.dropout(pooled_output)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
reshaped_logits = logits.view(-1, self.num_choices)
|
reshaped_logits = logits.view(-1, self.num_choices)
|
||||||
@@ -1088,7 +1117,8 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
|||||||
loss_fct = CrossEntropyLoss()
|
loss_fct = CrossEntropyLoss()
|
||||||
loss = loss_fct(reshaped_logits, labels)
|
loss = loss_fct(reshaped_logits, labels)
|
||||||
return loss
|
return loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, reshaped_logits
|
||||||
return reshaped_logits
|
return reshaped_logits
|
||||||
|
|
||||||
|
|
||||||
@@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
logits = model(input_ids, token_type_ids, input_mask)
|
logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, num_labels):
|
def __init__(self, config, num_labels, output_attentions=False):
|
||||||
super(BertForTokenClassification, self).__init__(config)
|
super(BertForTokenClassification, self).__init__(config)
|
||||||
|
self.output_attentions = output_attentions
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
self.bert = BertModel(config)
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, sequence_output, _ = outputs
|
||||||
|
else:
|
||||||
|
sequence_output, _ = outputs
|
||||||
sequence_output = self.dropout(sequence_output)
|
sequence_output = self.dropout(sequence_output)
|
||||||
logits = self.classifier(sequence_output)
|
logits = self.classifier(sequence_output)
|
||||||
|
|
||||||
@@ -1161,7 +1196,8 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||||
return loss
|
return loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, logits
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config, output_attentions=False):
|
||||||
super(BertForQuestionAnswering, self).__init__(config)
|
super(BertForQuestionAnswering, self).__init__(config)
|
||||||
self.bert = BertModel(config)
|
self.output_attentions = output_attentions
|
||||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
||||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||||
self.apply(self.init_bert_weights)
|
self.apply(self.init_bert_weights)
|
||||||
|
|
||||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
|
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
|
||||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions, sequence_output, _ = outputs
|
||||||
|
else:
|
||||||
|
sequence_output, _ = outputs
|
||||||
logits = self.qa_outputs(sequence_output)
|
logits = self.qa_outputs(sequence_output)
|
||||||
start_logits, end_logits = logits.split(1, dim=-1)
|
start_logits, end_logits = logits.split(1, dim=-1)
|
||||||
start_logits = start_logits.squeeze(-1)
|
start_logits = start_logits.squeeze(-1)
|
||||||
@@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
end_loss = loss_fct(end_logits, end_positions)
|
end_loss = loss_fct(end_logits, end_positions)
|
||||||
total_loss = (start_loss + end_loss) / 2
|
total_loss = (start_loss + end_loss) / 2
|
||||||
return total_loss
|
return total_loss
|
||||||
else:
|
elif self.output_attentions:
|
||||||
|
return all_attentions, start_logits, end_logits
|
||||||
return start_logits, end_logits
|
return start_logits, end_logits
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import torch
|
|||||||
from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
|
from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
|
||||||
BertForNextSentencePrediction, BertForPreTraining,
|
BertForNextSentencePrediction, BertForPreTraining,
|
||||||
BertForQuestionAnswering, BertForSequenceClassification,
|
BertForQuestionAnswering, BertForSequenceClassification,
|
||||||
BertForTokenClassification)
|
BertForTokenClassification, BertForMultipleChoice)
|
||||||
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
|
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
|
|
||||||
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
type_sequence_label_size=2,
|
type_sequence_label_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
num_labels=3,
|
num_labels=3,
|
||||||
|
num_choices=4,
|
||||||
scope=None):
|
scope=None):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
self.type_sequence_label_size = type_sequence_label_size
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
|
self.num_choices = num_choices
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
|
|||||||
|
|
||||||
sequence_labels = None
|
sequence_labels = None
|
||||||
token_labels = None
|
token_labels = None
|
||||||
|
choice_labels = None
|
||||||
if self.use_labels:
|
if self.use_labels:
|
||||||
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
|
||||||
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||||
|
choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
|
||||||
|
|
||||||
config = BertConfig(
|
config = BertConfig(
|
||||||
vocab_size_or_config_json_file=self.vocab_size,
|
vocab_size_or_config_json_file=self.vocab_size,
|
||||||
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
|
|||||||
type_vocab_size=self.type_vocab_size,
|
type_vocab_size=self.type_vocab_size,
|
||||||
initializer_range=self.initializer_range)
|
initializer_range=self.initializer_range)
|
||||||
|
|
||||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
|
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
|
|
||||||
def check_loss_output(self, result):
|
def check_loss_output(self, result):
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["loss"].size()),
|
list(result["loss"].size()),
|
||||||
[])
|
[])
|
||||||
|
|
||||||
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertModel(config=config)
|
model = BertModel(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
|
||||||
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||||
|
|
||||||
|
|
||||||
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForMaskedLM(config=config)
|
model = BertForMaskedLM(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss = model(input_ids, token_type_ids, input_mask, token_labels)
|
loss = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||||
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
list(result["prediction_scores"].size()),
|
list(result["prediction_scores"].size()),
|
||||||
[self.batch_size, self.seq_length, self.vocab_size])
|
[self.batch_size, self.seq_length, self.vocab_size])
|
||||||
|
|
||||||
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForNextSentencePrediction(config=config)
|
model = BertForNextSentencePrediction(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||||
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
[self.batch_size, 2])
|
[self.batch_size, 2])
|
||||||
|
|
||||||
|
|
||||||
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForPreTraining(config=config)
|
model = BertForPreTraining(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
|
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
|
||||||
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
[self.batch_size, 2])
|
[self.batch_size, 2])
|
||||||
|
|
||||||
|
|
||||||
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForQuestionAnswering(config=config)
|
model = BertForQuestionAnswering(config=config)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
|
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
|
||||||
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
[self.batch_size, self.seq_length])
|
[self.batch_size, self.seq_length])
|
||||||
|
|
||||||
|
|
||||||
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
|
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
|
||||||
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
|
|||||||
[self.batch_size, self.num_labels])
|
[self.batch_size, self.num_labels])
|
||||||
|
|
||||||
|
|
||||||
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
|
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
|
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
|
||||||
model.eval()
|
model.eval()
|
||||||
loss = model(input_ids, token_type_ids, input_mask, token_labels)
|
loss = model(input_ids, token_type_ids, input_mask, token_labels)
|
||||||
@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase):
|
|||||||
[self.batch_size, self.seq_length, self.num_labels])
|
[self.batch_size, self.seq_length, self.num_labels])
|
||||||
|
|
||||||
|
|
||||||
|
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
|
model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
|
||||||
|
model.eval()
|
||||||
|
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
loss = model(multiple_choice_inputs_ids,
|
||||||
|
multiple_choice_token_type_ids,
|
||||||
|
multiple_choice_input_mask,
|
||||||
|
choice_labels)
|
||||||
|
logits = model(multiple_choice_inputs_ids,
|
||||||
|
multiple_choice_token_type_ids,
|
||||||
|
multiple_choice_input_mask)
|
||||||
|
outputs = {
|
||||||
|
"loss": loss,
|
||||||
|
"logits": logits,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def check_bert_for_multiple_choice(self, result):
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(result["logits"].size()),
|
||||||
|
[self.batch_size, self.num_choices])
|
||||||
|
|
||||||
|
|
||||||
|
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||||
|
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
|
||||||
|
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
|
||||||
|
BertForTokenClassification):
|
||||||
|
if model_class in [BertForSequenceClassification,
|
||||||
|
BertForTokenClassification]:
|
||||||
|
model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
|
||||||
|
else:
|
||||||
|
model = model_class(config=config, output_attentions=True)
|
||||||
|
model.eval()
|
||||||
|
output = model(input_ids, token_type_ids, input_mask)
|
||||||
|
attentions = output[0]
|
||||||
|
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(attentions[0].size()),
|
||||||
|
[self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
|
||||||
|
|
||||||
|
|
||||||
def test_default(self):
|
def test_default(self):
|
||||||
self.run_tester(BertModelTest.BertModelTester(self))
|
self.run_tester(BertModelTest.BertModelTester(self))
|
||||||
|
|
||||||
@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase):
|
|||||||
tester.check_bert_for_token_classification_output(output_result)
|
tester.check_bert_for_token_classification_output(output_result)
|
||||||
tester.check_loss_output(output_result)
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
|
||||||
|
tester.check_bert_for_multiple_choice(output_result)
|
||||||
|
tester.check_loss_output(output_result)
|
||||||
|
|
||||||
|
tester.create_and_check_bert_for_attentions(*config_and_inputs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
|
|||||||
Reference in New Issue
Block a user