clean up PT <=> TF 2.0 conversion and config loading
This commit is contained in:
@@ -32,7 +32,7 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
||||
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||
@@ -47,7 +47,7 @@ if is_torch_available():
|
||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
@@ -59,7 +59,7 @@ else:
|
||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
||||
@@ -70,7 +70,7 @@ else:
|
||||
None, None,
|
||||
None, None,
|
||||
None, None, None,
|
||||
None, None, None,
|
||||
None, None, None, None,
|
||||
None, None,
|
||||
None, None,
|
||||
None, None)
|
||||
@@ -93,6 +93,7 @@ MODEL_CLASSES = {
|
||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||
|
||||
@@ -184,7 +184,9 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
config: (`optional`) one of:
|
||||
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
|
||||
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
@@ -236,10 +238,11 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
if config is None:
|
||||
# Load config if we don't provide a configuration
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config_path = config if config is not None else pretrained_model_name_or_path
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args,
|
||||
config_path, *model_args,
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
@@ -310,7 +313,11 @@ class TFPreTrainedModel(tf.keras.Model):
|
||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||
# 'by_name' allow us to do transfer learning by skipping/adding layers
|
||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||
model.load_weights(resolved_archive_file, by_name=True)
|
||||
try:
|
||||
model.load_weights(resolved_archive_file, by_name=True)
|
||||
except OSError:
|
||||
raise OSError("Unable to load weights from h5 file. "
|
||||
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")
|
||||
|
||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||
|
||||
|
||||
@@ -281,7 +281,9 @@ class PreTrainedModel(nn.Module):
|
||||
model_args: (`optional`) Sequence of positional arguments:
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
||||
|
||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
||||
config: (`optional`) one of:
|
||||
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
|
||||
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
|
||||
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
||||
|
||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||
@@ -336,10 +338,11 @@ class PreTrainedModel(nn.Module):
|
||||
proxies = kwargs.pop('proxies', None)
|
||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||
|
||||
# Load config
|
||||
if config is None:
|
||||
# Load config if we don't provide a configuration
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
config_path = config if config is not None else pretrained_model_name_or_path
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args,
|
||||
config_path, *model_args,
|
||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
@@ -408,7 +411,11 @@ class PreTrainedModel(nn.Module):
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
if state_dict is None and not from_tf:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
try:
|
||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||
except:
|
||||
raise OSError("Unable to load weights from pytorch checkpoint file. "
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")
|
||||
|
||||
missing_keys = []
|
||||
unexpected_keys = []
|
||||
|
||||
Reference in New Issue
Block a user