[AutoModel] Split AutoModelWithLMHead into clm, mlm, encoder-decoder (#4933)
* first commit * add new auto models * better naming * fix bert automodel * fix automodel for pretraining * add models to init * fix name typo * fix typo * better naming * future warning instead of deprecation warning
This commit is contained in:
committed by
GitHub
parent
5620033115
commit
86578bb04c
@@ -26,13 +26,20 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
BertConfig,
|
||||
GPT2Config,
|
||||
T5Config,
|
||||
AutoModel,
|
||||
BertModel,
|
||||
AutoModelForPreTraining,
|
||||
BertForPreTraining,
|
||||
AutoModelForCausalLM,
|
||||
GPT2LMHeadModel,
|
||||
AutoModelWithLMHead,
|
||||
AutoModelForMaskedLM,
|
||||
BertForMaskedLM,
|
||||
RobertaForMaskedLM,
|
||||
AutoModelForSeq2SeqLM,
|
||||
T5ForConditionalGeneration,
|
||||
AutoModelForSequenceClassification,
|
||||
BertForSequenceClassification,
|
||||
AutoModelForQuestionAnswering,
|
||||
@@ -41,6 +48,8 @@ if is_torch_available():
|
||||
BertForTokenClassification,
|
||||
)
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_auto import (
|
||||
MODEL_MAPPING,
|
||||
MODEL_FOR_PRETRAINING_MAPPING,
|
||||
@@ -48,6 +57,9 @@ if is_torch_available():
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
)
|
||||
|
||||
|
||||
@@ -97,6 +109,45 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
|
||||
@slow
|
||||
def test_model_for_causal_lm(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, GPT2Config)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForCausalLM.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, GPT2LMHeadModel)
|
||||
|
||||
@slow
|
||||
def test_model_for_masked_lm(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = AutoModelForMaskedLM.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForMaskedLM.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForMaskedLM)
|
||||
|
||||
@slow
|
||||
def test_model_for_encoder_decoder_lm(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, T5Config)
|
||||
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, T5ForConditionalGeneration)
|
||||
|
||||
@slow
|
||||
def test_sequence_classification_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -163,6 +214,9 @@ class AutoModelTest(unittest.TestCase):
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
)
|
||||
|
||||
for mapping in mappings:
|
||||
|
||||
@@ -27,6 +27,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
BertModel,
|
||||
BertLMHeadModel,
|
||||
BertForMaskedLM,
|
||||
BertForNextSentencePrediction,
|
||||
BertForPreTraining,
|
||||
@@ -35,7 +36,7 @@ if is_torch_available():
|
||||
BertForTokenClassification,
|
||||
BertForMultipleChoice,
|
||||
)
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST, BertLMHeadModel
|
||||
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
class BertModelTester:
|
||||
|
||||
Reference in New Issue
Block a user