From 34a3c25a3068ab5cdbecb08ddf2866f1209fd2dd Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Wed, 22 Jan 2020 17:50:24 -0500 Subject: [PATCH] Fix for XLMRobertaConfig inherits from RobertaConfig hat/tip @stefan-it --- src/transformers/modeling_auto.py | 8 ++++---- src/transformers/tokenization_auto.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index f7abd3157c..d18af9f0a5 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -130,8 +130,8 @@ MODEL_MAPPING = OrderedDict( (DistilBertConfig, DistilBertModel), (AlbertConfig, AlbertModel), (CamembertConfig, CamembertModel), - (RobertaConfig, RobertaModel), (XLMRobertaConfig, XLMRobertaModel), + (RobertaConfig, RobertaModel), (BertConfig, BertModel), (OpenAIGPTConfig, OpenAIGPTModel), (GPT2Config, GPT2Model), @@ -148,8 +148,8 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict( (DistilBertConfig, DistilBertForMaskedLM), (AlbertConfig, AlbertForMaskedLM), (CamembertConfig, CamembertForMaskedLM), - (RobertaConfig, RobertaForMaskedLM), (XLMRobertaConfig, XLMRobertaForMaskedLM), + (RobertaConfig, RobertaForMaskedLM), (BertConfig, BertForMaskedLM), (OpenAIGPTConfig, OpenAIGPTLMHeadModel), (GPT2Config, GPT2LMHeadModel), @@ -165,8 +165,8 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict( (DistilBertConfig, DistilBertForSequenceClassification), (AlbertConfig, AlbertForSequenceClassification), (CamembertConfig, CamembertForSequenceClassification), - (RobertaConfig, RobertaForSequenceClassification), (XLMRobertaConfig, XLMRobertaForSequenceClassification), + (RobertaConfig, RobertaForSequenceClassification), (BertConfig, BertForSequenceClassification), (XLNetConfig, XLNetForSequenceClassification), (XLMConfig, XLMForSequenceClassification), @@ -187,8 +187,8 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict( [ (DistilBertConfig, DistilBertForTokenClassification), (CamembertConfig, CamembertForTokenClassification), - (RobertaConfig, RobertaForTokenClassification), (XLMRobertaConfig, XLMRobertaForTokenClassification), + (RobertaConfig, RobertaForTokenClassification), (BertConfig, BertForTokenClassification), (XLNetConfig, XLNetForTokenClassification), ] diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 8523331ed6..368833fa58 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -60,8 +60,8 @@ TOKENIZER_MAPPING = OrderedDict( (DistilBertConfig, DistilBertTokenizer), (AlbertConfig, AlbertTokenizer), (CamembertConfig, CamembertTokenizer), - (RobertaConfig, RobertaTokenizer), (XLMRobertaConfig, XLMRobertaTokenizer), + (RobertaConfig, RobertaTokenizer), (BertConfig, BertTokenizer), (OpenAIGPTConfig, OpenAIGPTTokenizer), (GPT2Config, GPT2Tokenizer),