From fa963ecc59a1dea59c2d0e952b2c4483e1828176 Mon Sep 17 00:00:00 2001 From: Evpok Padding Date: Thu, 21 Nov 2019 11:03:12 +0100 Subject: [PATCH] =?UTF-8?q?if=E2=86=92elif?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- transformers/modeling_auto.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformers/modeling_auto.py b/transformers/modeling_auto.py index ce36f6dc4a..5866420001 100644 --- a/transformers/modeling_auto.py +++ b/transformers/modeling_auto.py @@ -271,7 +271,7 @@ class AutoModelWithLMHead(object): """ if 'distilbert' in pretrained_model_name_or_path: return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - if 'camembert' in pretrained_model_name_or_path: + elif 'camembert' in pretrained_model_name_or_path: return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return RobertaForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @@ -394,7 +394,7 @@ class AutoModelForSequenceClassification(object): """ if 'distilbert' in pretrained_model_name_or_path: return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - if 'camembert' in pretrained_model_name_or_path: + elif 'camembert' in pretrained_model_name_or_path: return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return RobertaForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)