dilbert -> distilbert
This commit is contained in:
@@ -20,23 +20,23 @@ import unittest
|
||||
import shutil
|
||||
import pytest
|
||||
|
||||
from pytorch_transformers import (DilBertConfig, DilBertModel, DilBertForMaskedLM,
|
||||
DilBertForQuestionAnswering, DilBertForSequenceClassification)
|
||||
from pytorch_transformers.modeling_dilbert import DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
|
||||
DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
|
||||
from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
|
||||
|
||||
|
||||
class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
class DistilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
|
||||
all_model_classes = (DilBertModel, DilBertForMaskedLM, DilBertForQuestionAnswering,
|
||||
DilBertForSequenceClassification)
|
||||
all_model_classes = (DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
|
||||
DistilBertForSequenceClassification)
|
||||
test_pruning = True
|
||||
test_torchscript = True
|
||||
test_resize_embeddings = True
|
||||
test_head_masking = True
|
||||
|
||||
class DilBertModelTester(object):
|
||||
class DistilBertModelTester(object):
|
||||
|
||||
def __init__(self,
|
||||
parent,
|
||||
@@ -100,7 +100,7 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
||||
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
||||
|
||||
config = DilBertConfig(
|
||||
config = DistilBertConfig(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
dim=self.hidden_size,
|
||||
n_layers=self.num_hidden_layers,
|
||||
@@ -119,8 +119,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
list(result["loss"].size()),
|
||||
[])
|
||||
|
||||
def create_and_check_dilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertModel(config=config)
|
||||
def create_and_check_distilbert_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DistilBertModel(config=config)
|
||||
model.eval()
|
||||
(sequence_output,) = model(input_ids, input_mask)
|
||||
(sequence_output,) = model(input_ids)
|
||||
@@ -132,8 +132,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
list(result["sequence_output"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
|
||||
def create_and_check_dilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertForMaskedLM(config=config)
|
||||
def create_and_check_distilbert_for_masked_lm(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DistilBertForMaskedLM(config=config)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, masked_lm_labels=token_labels)
|
||||
result = {
|
||||
@@ -145,8 +145,8 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
[self.batch_size, self.seq_length, self.vocab_size])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_dilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DilBertForQuestionAnswering(config=config)
|
||||
def create_and_check_distilbert_for_question_answering(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
model = DistilBertForQuestionAnswering(config=config)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(input_ids, input_mask, sequence_labels, sequence_labels)
|
||||
result = {
|
||||
@@ -162,9 +162,9 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
[self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_dilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
def create_and_check_distilbert_for_sequence_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
|
||||
config.num_labels = self.num_labels
|
||||
model = DilBertForSequenceClassification(config)
|
||||
model = DistilBertForSequenceClassification(config)
|
||||
model.eval()
|
||||
loss, logits = model(input_ids, input_mask, sequence_labels)
|
||||
result = {
|
||||
@@ -183,33 +183,33 @@ class DilBertModelTest(CommonTestCases.CommonModelTester):
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = DilBertModelTest.DilBertModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DilBertConfig, dim=37)
|
||||
self.model_tester = DistilBertModelTest.DistilBertModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=DistilBertConfig, dim=37)
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_dilbert_model(self):
|
||||
def test_distilbert_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dilbert_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_distilbert_model(*config_and_inputs)
|
||||
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dilbert_for_masked_lm(*config_and_inputs)
|
||||
self.model_tester.create_and_check_distilbert_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dilbert_for_question_answering(*config_and_inputs)
|
||||
self.model_tester.create_and_check_distilbert_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_dilbert_for_sequence_classification(*config_and_inputs)
|
||||
self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
# @pytest.mark.slow
|
||||
# def test_model_from_pretrained(self):
|
||||
# cache_dir = "/tmp/pytorch_transformers_test/"
|
||||
# for model_name in list(DILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# model = DilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
# model = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
# shutil.rmtree(cache_dir)
|
||||
# self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@@ -18,20 +18,20 @@ import os
|
||||
import unittest
|
||||
from io import open
|
||||
|
||||
from pytorch_transformers.tokenization_dilbert import (DilBertTokenizer)
|
||||
from pytorch_transformers.tokenization_distilbert import (DistilBertTokenizer)
|
||||
|
||||
from .tokenization_tests_commons import CommonTestCases
|
||||
from .tokenization_bert_test import BertTokenizationTest
|
||||
|
||||
class DilBertTokenizationTest(BertTokenizationTest):
|
||||
class DistilBertTokenizationTest(BertTokenizationTest):
|
||||
|
||||
tokenizer_class = DilBertTokenizer
|
||||
tokenizer_class = DistilBertTokenizer
|
||||
|
||||
def get_tokenizer(self):
|
||||
return DilBertTokenizer.from_pretrained(self.tmpdirname)
|
||||
return DistilBertTokenizer.from_pretrained(self.tmpdirname)
|
||||
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = DilBertTokenizer.from_pretrained("dilbert-base-uncased")
|
||||
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
text = tokenizer.encode("sequence builders")
|
||||
text_2 = tokenizer.encode("multi-sequence build")
|
||||
|
||||
Reference in New Issue
Block a user