More AutoConfig tests
This commit is contained in:
@@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig
|
|||||||
from transformers.configuration_bert import BertConfig
|
from transformers.configuration_bert import BertConfig
|
||||||
from transformers.configuration_roberta import RobertaConfig
|
from transformers.configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
|
from .utils import DUMMY_UNKWOWN_IDENTIFIER
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||||
|
|
||||||
@@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase):
|
|||||||
config = AutoConfig.from_pretrained("bert-base-uncased")
|
config = AutoConfig.from_pretrained("bert-base-uncased")
|
||||||
self.assertIsInstance(config, BertConfig)
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
def test_config_from_model_type(self):
|
def test_config_model_type_from_local_file(self):
|
||||||
config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
|
config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
|
||||||
self.assertIsInstance(config, RobertaConfig)
|
self.assertIsInstance(config, RobertaConfig)
|
||||||
|
|
||||||
|
def test_config_model_type_from_model_identifier(self):
|
||||||
|
config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||||
|
self.assertIsInstance(config, RobertaConfig)
|
||||||
|
|
||||||
def test_config_for_model_str(self):
|
def test_config_for_model_str(self):
|
||||||
config = AutoConfig.for_model("roberta")
|
config = AutoConfig.for_model("roberta")
|
||||||
self.assertIsInstance(config, RobertaConfig)
|
self.assertIsInstance(config, RobertaConfig)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import unittest
|
|||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .utils import SMALL_MODEL_IDENTIFIER, require_torch, slow
|
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -30,6 +30,7 @@ if is_torch_available():
|
|||||||
BertModel,
|
BertModel,
|
||||||
AutoModelWithLMHead,
|
AutoModelWithLMHead,
|
||||||
BertForMaskedLM,
|
BertForMaskedLM,
|
||||||
|
RobertaForMaskedLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
BertForSequenceClassification,
|
BertForSequenceClassification,
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
@@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsInstance(model, BertForMaskedLM)
|
self.assertIsInstance(model, BertForMaskedLM)
|
||||||
self.assertEqual(model.num_parameters(), 14830)
|
self.assertEqual(model.num_parameters(), 14830)
|
||||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||||
|
|
||||||
|
def test_from_identifier_from_model_type(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||||
|
self.assertIsInstance(model, RobertaForMaskedLM)
|
||||||
|
self.assertEqual(model.num_parameters(), 14830)
|
||||||
|
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import unittest
|
|||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
|
|
||||||
from .utils import SMALL_MODEL_IDENTIFIER, require_tf, slow
|
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -30,6 +30,7 @@ if is_tf_available():
|
|||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFAutoModelWithLMHead,
|
TFAutoModelWithLMHead,
|
||||||
TFBertForMaskedLM,
|
TFBertForMaskedLM,
|
||||||
|
TFRobertaForMaskedLM,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertForSequenceClassification,
|
TFBertForSequenceClassification,
|
||||||
TFAutoModelForQuestionAnswering,
|
TFAutoModelForQuestionAnswering,
|
||||||
@@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
self.assertIsInstance(model, TFBertForMaskedLM)
|
self.assertIsInstance(model, TFBertForMaskedLM)
|
||||||
self.assertEqual(model.num_parameters(), 14830)
|
self.assertEqual(model.num_parameters(), 14830)
|
||||||
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||||
|
|
||||||
|
def test_from_identifier_from_model_type(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||||
|
self.assertIsInstance(model, TFRobertaForMaskedLM)
|
||||||
|
self.assertEqual(model.num_parameters(), 14830)
|
||||||
|
self.assertEqual(model.num_parameters(only_trainable=True), 14830)
|
||||||
|
|||||||
@@ -23,9 +23,10 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
GPT2Tokenizer,
|
GPT2Tokenizer,
|
||||||
|
RobertaTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .utils import SMALL_MODEL_IDENTIFIER, slow # noqa: F401
|
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class AutoTokenizerTest(unittest.TestCase):
|
class AutoTokenizerTest(unittest.TestCase):
|
||||||
@@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
|
||||||
self.assertIsInstance(tokenizer, BertTokenizer)
|
self.assertIsInstance(tokenizer, BertTokenizer)
|
||||||
self.assertEqual(len(tokenizer), 12)
|
self.assertEqual(len(tokenizer), 12)
|
||||||
|
|
||||||
|
def test_tokenizer_from_model_type(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
|
||||||
|
self.assertIsInstance(tokenizer, RobertaTokenizer)
|
||||||
|
self.assertEqual(len(tokenizer), 20)
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available
|
|||||||
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
|
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
|
||||||
|
|
||||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||||
|
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
|
||||||
|
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
|
||||||
|
|
||||||
|
|
||||||
def parse_flag_from_env(key, default=False):
|
def parse_flag_from_env(key, default=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user