From cf8a70bf68713648b4f8b609b118414d3c4e33dc Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Sat, 11 Jan 2020 03:43:57 +0000 Subject: [PATCH] More AutoConfig tests --- tests/test_configuration_auto.py | 8 +++++++- tests/test_modeling_auto.py | 10 +++++++++- tests/test_modeling_tf_auto.py | 10 +++++++++- tests/test_tokenization_auto.py | 9 ++++++++- tests/utils.py | 2 ++ 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/tests/test_configuration_auto.py b/tests/test_configuration_auto.py index 842732da46..a6f3376d75 100644 --- a/tests/test_configuration_auto.py +++ b/tests/test_configuration_auto.py @@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig from transformers.configuration_bert import BertConfig from transformers.configuration_roberta import RobertaConfig +from .utils import DUMMY_UNKWOWN_IDENTIFIER + SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json") @@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase): config = AutoConfig.from_pretrained("bert-base-uncased") self.assertIsInstance(config, BertConfig) - def test_config_from_model_type(self): + def test_config_model_type_from_local_file(self): config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG) self.assertIsInstance(config, RobertaConfig) + def test_config_model_type_from_model_identifier(self): + config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER) + self.assertIsInstance(config, RobertaConfig) + def test_config_for_model_str(self): config = AutoConfig.for_model("roberta") self.assertIsInstance(config, RobertaConfig) diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index dcf2526577..435fc04cf0 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -19,7 +19,7 @@ import unittest from transformers import is_torch_available -from .utils import SMALL_MODEL_IDENTIFIER, require_torch, slow +from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow if is_torch_available(): @@ -30,6 +30,7 @@ if is_torch_available(): BertModel, AutoModelWithLMHead, BertForMaskedLM, + RobertaForMaskedLM, AutoModelForSequenceClassification, BertForSequenceClassification, AutoModelForQuestionAnswering, @@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase): self.assertIsInstance(model, BertForMaskedLM) self.assertEqual(model.num_parameters(), 14830) self.assertEqual(model.num_parameters(only_trainable=True), 14830) + + def test_from_identifier_from_model_type(self): + logging.basicConfig(level=logging.INFO) + model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER) + self.assertIsInstance(model, RobertaForMaskedLM) + self.assertEqual(model.num_parameters(), 14830) + self.assertEqual(model.num_parameters(only_trainable=True), 14830) diff --git a/tests/test_modeling_tf_auto.py b/tests/test_modeling_tf_auto.py index 56d5f3efbe..e86499572f 100644 --- a/tests/test_modeling_tf_auto.py +++ b/tests/test_modeling_tf_auto.py @@ -19,7 +19,7 @@ import unittest from transformers import is_tf_available -from .utils import SMALL_MODEL_IDENTIFIER, require_tf, slow +from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow if is_tf_available(): @@ -30,6 +30,7 @@ if is_tf_available(): TFBertModel, TFAutoModelWithLMHead, TFBertForMaskedLM, + TFRobertaForMaskedLM, TFAutoModelForSequenceClassification, TFBertForSequenceClassification, TFAutoModelForQuestionAnswering, @@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase): self.assertIsInstance(model, TFBertForMaskedLM) self.assertEqual(model.num_parameters(), 14830) self.assertEqual(model.num_parameters(only_trainable=True), 14830) + + def test_from_identifier_from_model_type(self): + logging.basicConfig(level=logging.INFO) + model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER) + self.assertIsInstance(model, TFRobertaForMaskedLM) + self.assertEqual(model.num_parameters(), 14830) + self.assertEqual(model.num_parameters(only_trainable=True), 14830) diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py index e9d23c64bc..75849216bd 100644 --- a/tests/test_tokenization_auto.py +++ b/tests/test_tokenization_auto.py @@ -23,9 +23,10 @@ from transformers import ( AutoTokenizer, BertTokenizer, GPT2Tokenizer, + RobertaTokenizer, ) -from .utils import SMALL_MODEL_IDENTIFIER, slow # noqa: F401 +from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow # noqa: F401 class AutoTokenizerTest(unittest.TestCase): @@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase): tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER) self.assertIsInstance(tokenizer, BertTokenizer) self.assertEqual(len(tokenizer), 12) + + def test_tokenizer_from_model_type(self): + logging.basicConfig(level=logging.INFO) + tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER) + self.assertIsInstance(tokenizer, RobertaTokenizer) + self.assertEqual(len(tokenizer), 20) diff --git a/tests/utils.py b/tests/utils.py index 66ff53d6ee..163628d3a7 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test") SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" +DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown" +# Used to test Auto{Config, Model, Tokenizer} model_type detection. def parse_flag_from_env(key, default=False):