clean up PT <=> TF 2.0 conversion and config loading
This commit is contained in:
@@ -32,7 +32,7 @@ from transformers import (load_pytorch_checkpoint_in_tf2_model,
|
|||||||
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
TransfoXLConfig, TFTransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
OpenAIGPTConfig, TFOpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
RobertaConfig, TFRobertaForMaskedLM, TFRobertaForSequenceClassification, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
DistilBertConfig, TFDistilBertForMaskedLM, TFDistilBertForQuestionAnswering, TFDistilBertForSequenceClassification, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
CTRLConfig, TFCTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
AlbertConfig, TFAlbertForMaskedLM, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
T5Config, TFT5WithLMHeadModel, T5_PRETRAINED_CONFIG_ARCHIVE_MAP)
|
||||||
@@ -47,7 +47,7 @@ if is_torch_available():
|
|||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||||
@@ -59,7 +59,7 @@ else:
|
|||||||
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TransfoXLLMHeadModel, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
OpenAIGPTLMHeadModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
RobertaForMaskedLM, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
DistilBertForMaskedLM, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
DistilBertForMaskedLM, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP) = (
|
||||||
@@ -70,7 +70,7 @@ else:
|
|||||||
None, None,
|
None, None,
|
||||||
None, None,
|
None, None,
|
||||||
None, None, None,
|
None, None, None,
|
||||||
None, None, None,
|
None, None, None, None,
|
||||||
None, None,
|
None, None,
|
||||||
None, None,
|
None, None,
|
||||||
None, None)
|
None, None)
|
||||||
@@ -93,6 +93,7 @@ MODEL_CLASSES = {
|
|||||||
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'roberta-large-mnli': (RobertaConfig, TFRobertaForSequenceClassification, RobertaForSequenceClassification, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'distilbert': (DistilBertConfig, TFDistilBertForMaskedLM, DistilBertForMaskedLM, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
|
'distilbert-base-uncased-distilled-squad': (DistilBertConfig, TFDistilBertForQuestionAnswering, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'ctrl': (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
'albert': (AlbertConfig, TFAlbertForMaskedLM, AlbertForMaskedLM, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
't5': (T5Config, TFT5WithLMHeadModel, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP),
|
||||||
|
|||||||
@@ -184,7 +184,9 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
model_args: (`optional`) Sequence of positional arguments:
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
|
||||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
config: (`optional`) one of:
|
||||||
|
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
|
||||||
|
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
|
||||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
@@ -236,10 +238,11 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
proxies = kwargs.pop('proxies', None)
|
proxies = kwargs.pop('proxies', None)
|
||||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||||
|
|
||||||
# Load config
|
# Load config if we don't provide a configuration
|
||||||
if config is None:
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config_path = config if config is not None else pretrained_model_name_or_path
|
||||||
config, model_kwargs = cls.config_class.from_pretrained(
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
pretrained_model_name_or_path, *model_args,
|
config_path, *model_args,
|
||||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
@@ -310,7 +313,11 @@ class TFPreTrainedModel(tf.keras.Model):
|
|||||||
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
|
||||||
# 'by_name' allows us to do transfer learning by skipping/adding layers
|
# 'by_name' allows us to do transfer learning by skipping/adding layers
|
||||||
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
|
||||||
|
try:
|
||||||
model.load_weights(resolved_archive_file, by_name=True)
|
model.load_weights(resolved_archive_file, by_name=True)
|
||||||
|
except OSError:
|
||||||
|
raise OSError("Unable to load weights from h5 file. "
|
||||||
|
"If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. ")
|
||||||
|
|
||||||
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
ret = model(model.dummy_inputs, training=False) # Make sure restore ops are run
|
||||||
|
|
||||||
|
|||||||
@@ -281,7 +281,9 @@ class PreTrainedModel(nn.Module):
|
|||||||
model_args: (`optional`) Sequence of positional arguments:
|
model_args: (`optional`) Sequence of positional arguments:
|
||||||
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
|
||||||
|
|
||||||
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
|
config: (`optional`) one of:
|
||||||
|
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
|
||||||
|
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
|
||||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
|
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
|
||||||
|
|
||||||
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
||||||
@@ -336,10 +338,11 @@ class PreTrainedModel(nn.Module):
|
|||||||
proxies = kwargs.pop('proxies', None)
|
proxies = kwargs.pop('proxies', None)
|
||||||
output_loading_info = kwargs.pop('output_loading_info', False)
|
output_loading_info = kwargs.pop('output_loading_info', False)
|
||||||
|
|
||||||
# Load config
|
# Load config if we don't provide a configuration
|
||||||
if config is None:
|
if not isinstance(config, PretrainedConfig):
|
||||||
|
config_path = config if config is not None else pretrained_model_name_or_path
|
||||||
config, model_kwargs = cls.config_class.from_pretrained(
|
config, model_kwargs = cls.config_class.from_pretrained(
|
||||||
pretrained_model_name_or_path, *model_args,
|
config_path, *model_args,
|
||||||
cache_dir=cache_dir, return_unused_kwargs=True,
|
cache_dir=cache_dir, return_unused_kwargs=True,
|
||||||
force_download=force_download,
|
force_download=force_download,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
@@ -408,7 +411,11 @@ class PreTrainedModel(nn.Module):
|
|||||||
model = cls(config, *model_args, **model_kwargs)
|
model = cls(config, *model_args, **model_kwargs)
|
||||||
|
|
||||||
if state_dict is None and not from_tf:
|
if state_dict is None and not from_tf:
|
||||||
|
try:
|
||||||
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
||||||
|
except:
|
||||||
|
raise OSError("Unable to load weights from pytorch checkpoint file. "
|
||||||
|
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. ")
|
||||||
|
|
||||||
missing_keys = []
|
missing_keys = []
|
||||||
unexpected_keys = []
|
unexpected_keys = []
|
||||||
|
|||||||
Reference in New Issue
Block a user