GilBert added to AutoModels
This commit is contained in:
@@ -30,6 +30,7 @@ from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel
|
|||||||
from .modeling_xlnet import XLNetConfig, XLNetModel
|
from .modeling_xlnet import XLNetConfig, XLNetModel
|
||||||
from .modeling_xlm import XLMConfig, XLMModel
|
from .modeling_xlm import XLMConfig, XLMModel
|
||||||
from .modeling_roberta import RobertaConfig, RobertaModel
|
from .modeling_roberta import RobertaConfig, RobertaModel
|
||||||
|
from .modeling_dilbert import DilBertConfig, DilBertModel
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
|
||||||
@@ -110,7 +111,9 @@ class AutoConfig(object):
|
|||||||
assert unused_kwargs == {'foo': False}
|
assert unused_kwargs == {'foo': False}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'roberta' in pretrained_model_name_or_path:
|
if 'dilbert' in pretrained_model_name_or_path:
|
||||||
|
return DilBertconfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif 'bert' in pretrained_model_name_or_path:
|
||||||
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
@@ -225,7 +228,9 @@ class AutoModel(object):
|
|||||||
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'roberta' in pretrained_model_name_or_path:
|
if 'dilbert' in pretrained_model_name_or_path:
|
||||||
|
return DilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif 'bert' in pretrained_model_name_or_path:
|
||||||
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user