Fix the DilBert model to support TorchScript, embedding resizing, head pruning and head masking; update the tests

This commit is contained in:
thomwolf
2019-08-28 13:22:45 +02:00
parent 62df4ba59a
commit c9bce1811c
3 changed files with 253 additions and 138 deletions

View File

@@ -21,7 +21,7 @@ import shutil
import pytest
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
DilBertForQuestionAnswering, DilBertForSequenceClassification)
DilBertForQuestionAnswering, DilBertForSequenceClassification)
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
# NOTE(review): these lines come from a rendered diff hunk ("@@ -31,10 +31,10 @@")
# whose +/- markers and indentation were lost, so the old and the new side of
# the change appear back to back.
# Model classes listed for this test case.
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
DilBertForSequenceClassification)
# Presumably the pre-commit values (removed side of the hunk) - confirm against
# the repository history.
test_pruning = False
test_torchscript = False
test_resize_embeddings = False
test_head_masking = False
# Presumably the post-commit values (added side of the hunk). Read as plain
# Python, these later assignments override the four above, so all flags end up True.
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
class DilBertModelTester(object):
@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
# NOTE(review): this method is shown as a rendered diff hunk ("@@ -122,22 +122,20 @@")
# with the +/- markers and indentation stripped; removed and added lines are
# interleaved below. Confirm against the repository before treating it as runnable code.
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
# Build the base model from the given config and put it in eval mode.
model = DilBertModel(config=config)
model.eval()
# Two-output unpacking, called with and without the input mask.
# Presumably the removed (pre-commit) side of the hunk.
sequence_output, pooled_output = model(input_ids, input_mask)
sequence_output, pooled_output = model(input_ids)
# Single-element tuple unpacking of the same two calls.
# Presumably the added (post-commit) side of the hunk.
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
result = {
"sequence_output": sequence_output,
# pooled_output is only bound by the two-output lines above; likely a removed line.
"pooled_output": pooled_output,
}
# The sequence output size must equal [batch_size, seq_length, hidden_size].
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
# Size check on pooled_output ([batch_size, hidden_size]); likely a removed line,
# since it depends on the two-output unpacking.
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, input_mask, token_labels)
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = {
"loss": loss,
"prediction_scores": prediction_scores,