From 778a263f09537e0d3667516c1fa674c9d331bc76 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Tue, 27 Aug 2019 22:28:42 -0400 Subject: [PATCH] GilBert added to AutoModels --- pytorch_transformers/modeling_auto.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index 516107c40b..2d28a6017f 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel from .modeling_xlnet import XLNetConfig, XLNetModel from .modeling_xlm import XLMConfig, XLMModel from .modeling_roberta import RobertaConfig, RobertaModel +from .modeling_dilbert import DilBertConfig, DilBertModel from .modeling_utils import PreTrainedModel, SequenceSummary @@ -110,7 +111,9 @@ class AutoConfig(object): assert unused_kwargs == {'foo': False} """ - if 'roberta' in pretrained_model_name_or_path: + if 'dilbert' in pretrained_model_name_or_path: + return DilBertconfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) elif 'bert' in pretrained_model_name_or_path: return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) @@ -225,7 +228,9 @@ class AutoModel(object): model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) """ - if 'roberta' in pretrained_model_name_or_path: + if 'dilbert' in pretrained_model_name_or_path: + return DilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) elif 'bert' in pretrained_model_name_or_path: return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)