Rename dilbert -> distilbert

This commit is contained in:
thomwolf
2019-08-28 13:59:42 +02:00
parent c9bce1811c
commit 912a377e90
15 changed files with 144 additions and 144 deletions

View File

@@ -20,23 +20,23 @@ import unittest
import shutil
import pytest
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
DilBertForQuestionAnswering, DilBertForSequenceClassification)
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
class DilBertModelTest(CommonTestCases.CommonModelTester):
class DistilBertModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
DilBertForSequenceClassification)
all_model_classes = (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
DistilBertForSequenceClassification)
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_head_masking = True
class DilBertModelTester(object):
class DistilBertModelTester(object):
def __init__(self,
parent,
@@ -100,7 +100,7 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = DilBertConfig(
config = DistilBertConfig(
vocab_size_or_config_json_file=self.vocab_size,
dim=self.hidden_size,
n_layers=self.num_hidden_layers,
@@ -119,8 +119,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
list(result["loss"].size()),
[])
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertModel(config=config)
def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertModel(config=config)
model.eval()
(sequence_output,) = model(input_ids, input_mask)
(sequence_output,) = model(input_ids)
@@ -132,8 +132,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertForMaskedLM(config=config)
def create_and_check_distilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
result = {
@@ -145,8 +145,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_and_check_dilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DilBertForQuestionAnswering(config=config)
def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = DistilBertForQuestionAnswering(config=config)
model.eval()
loss, start_logits, end_logits = model(input_ids, input_mask, sequence_labels, sequence_labels)
result = {
@@ -162,9 +162,9 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_and_check_dilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = DilBertForSequenceClassification(config)
model = DistilBertForSequenceClassification(config)
model.eval()
loss, logits = model(input_ids, input_mask, sequence_labels)
result = {
@@ -183,33 +183,33 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
return config, inputs_dict
def setUp(self):
self.model_tester = DilBertModelTest.DilBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=DilBertConfig, dim=37)
self.model_tester = DistilBertModelTest.DistilBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_dilbert_model(self):
def test_distilbert_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dilbert_model(*config_and_inputs)
self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dilbert_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dilbert_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_dilbert_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
# @pytest.mark.slow
# def test_model_from_pretrained(self):
# cache_dir = "/tmp/pytorch_transformers_test/"
# for model_name in list(DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
# shutil.rmtree(cache_dir)
# self.assertIsNotNone(model)

View File

@@ -18,20 +18,20 @@ import os
import unittest
from io import open
from pytorch_transformers.tokenization_dilbert import (DilBertTokenizer)
from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
from .tokenization_tests_commons import CommonTestCases
from .tokenization_bert_test import BertTokenizationTest
class DilBertTokenizationTest(BertTokenizationTest):
class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DilBertTokenizer
tokenizer_class = DistilBertTokenizer
def get_tokenizer(self):
return DilBertTokenizer.from_pretrained(self.tmpdirname)
return DistilBertTokenizer.from_pretrained(self.tmpdirname)
def test_sequence_builders(self):
tokenizer = DilBertTokenizer.from_pretrained("dilbert-base-uncased")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
text = tokenizer.encode("sequence builders")
text_2 = tokenizer.encode("multi-sequence build")