fixing model to add torchscript, embedding resizing, head pruning and masking + tests
This commit is contained in:
@@ -21,7 +21,7 @@ import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
|
||||
DilBertForQuestionAnswering, DilBertForSequenceClassification)
|
||||
DilBertForQuestionAnswering, DilBertForSequenceClassification)
|
||||
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
|
||||
@@ -31,10 +31,10 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
|
||||
DilBertForSequenceClassification)
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
class DilBertModelTester(object):
|
||||
|
||||
@@ -122,22 +122,20 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertModel(config=config)
|
||||
model.eval()
|
||||
sequence_output, pooled_output = model(input_ids, input_mask)
|
||||
sequence_output, pooled_output = model(input_ids)
|
||||
(sequence_output,) = model(input_ids, input_mask)
|
||||
(sequence_output,) = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
"pooled_output": pooled_output,
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertForMaskedLM(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, input_mask, token_labels)
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
|
||||
Reference in New Issue
Block a user