dilbert -> distilbert

This commit is contained in:
thomwolf
2019-08-28 13:59:42 +02:00
parent c9bce1811c
commit 912a377e90
15 changed files with 144 additions and 144 deletions

View File

@@ -30,7 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
from .modeling_xlnet import XLNetConfig, XLNetModel
from .modeling_xlm import XLMConfig, XLMModel
from .modeling_roberta import RobertaConfig, RobertaModel
from .modeling_dilbert import DilBertConfig, DilBertModel
from .modeling_distilbert import DistilBertConfig, DistilBertModel
from .modeling_utils import PreTrainedModel, SequenceSummary
@@ -111,8 +111,8 @@ class AutoConfig(object):
assert unused_kwargs == {'foo': False}
"""
if 'dilbert' in pretrained_model_name_or_path:
return DilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'bert' in pretrained_model_name_or_path:
@@ -228,8 +228,8 @@ class AutoModel(object):
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
if 'dilbert' in pretrained_model_name_or_path:
return DilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
if 'distilbert' in pretrained_model_name_or_path:
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'roberta' in pretrained_model_name_or_path:
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
elif 'bert' in pretrained_model_name_or_path: