From ab05280666c9e1cfbbb23122825f3a41b7ff82c3 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 16 Aug 2019 09:53:26 -0400 Subject: [PATCH] Order of strings in AutoModel/AutoTokenizer updated. --- pytorch_transformers/modeling_auto.py | 12 ++++++------ pytorch_transformers/tokenization_auto.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index 7c96b7a287..516107c40b 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -110,7 +110,9 @@ class AutoConfig(object): assert unused_kwargs == {'foo': False} """ - if 'bert' in pretrained_model_name_or_path: + if 'roberta' in pretrained_model_name_or_path: + return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'bert' in pretrained_model_name_or_path: return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) elif 'openai-gpt' in pretrained_model_name_or_path: return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -122,8 +124,6 @@ class AutoConfig(object): return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'roberta' in pretrained_model_name_or_path: - return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " @@ -225,7 +225,9 @@ class AutoModel(object): model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if 'bert' in pretrained_model_name_or_path: + if 'roberta' in pretrained_model_name_or_path: + return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'bert' in pretrained_model_name_or_path: return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'openai-gpt' in pretrained_model_name_or_path: return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) @@ -237,8 +239,6 @@ class AutoModel(object): return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - elif 'roberta' in pretrained_model_name_or_path: - return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " diff --git a/pytorch_transformers/tokenization_auto.py b/pytorch_transformers/tokenization_auto.py index adb8f87cd7..b4b6336952 100644 --- a/pytorch_transformers/tokenization_auto.py +++ b/pytorch_transformers/tokenization_auto.py @@ -85,7 +85,9 @@ class AutoTokenizer(object): config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')` """ - if 'bert' in pretrained_model_name_or_path: + if 'roberta' in pretrained_model_name_or_path: + return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif 'bert' in pretrained_model_name_or_path: return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) elif 'openai-gpt' in pretrained_model_name_or_path: return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) @@ -97,8 +99,6 @@ class AutoTokenizer(object): return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) elif 'xlm' in pretrained_model_name_or_path: return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) - elif 'roberta' in pretrained_model_name_or_path: - return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "