diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index a41c838015..f13457f073 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -96,8 +96,15 @@ if _tf_available: logger.info("TensorFlow version {} available.".format(tf.__version__)) from .modeling_tf_utils import TFPreTrainedModel + from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, + TFAutoModelWithLMHead) + from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertModel, TFBertForPreTraining, - TFBertForMaskedLM, TFBertForNextSentencePrediction, load_bert_pt_weights_in_tf) + TFBertForMaskedLM, TFBertForNextSentencePrediction, + TFBertForSequenceClassification, TFBertForMultipleChoice, + TFBertForTokenClassification, TFBertForQuestionAnswering, + load_bert_pt_weights_in_tf, + TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP) # Files and general utilities diff --git a/pytorch_transformers/modeling_tf_auto.py b/pytorch_transformers/modeling_tf_auto.py new file mode 100644 index 0000000000..3b10d700a1 --- /dev/null +++ b/pytorch_transformers/modeling_tf_auto.py @@ -0,0 +1,488 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Auto Model class. """ + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging + +from .modeling_tf_bert import TFBertModel, TFBertForMaskedLM, TFBertForSequenceClassification, TFBertForQuestionAnswering + +from .file_utils import add_start_docstrings + +logger = logging.getLogger(__name__) + + +class TFAutoModel(object): + r""" + :class:`~pytorch_transformers.TFAutoModel` is a generic model class + that will be instantiated as one of the base model classes of the library + when created with the `TFAutoModel.from_pretrained(pretrained_model_name_or_path)` + class method. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The base model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertModel (DistilBERT model) + - contains `roberta`: RobertaModel (RoBERTa model) + - contains `bert`: TFBertModel (Bert model) + - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) + - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) + - contains `xlnet`: XLNetModel (XLNet model) + - contains `xlm`: XLMModel (XLM model) + + This class cannot be instantiated using `__init__()` (throws an error). + """ + def __init__(self): + raise EnvironmentError("TFAutoModel is designed to be instantiated " + "using the `TFAutoModel.from_pretrained(pretrained_model_name_or_path)` method.") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the base model classes of the library + from a pre-trained model configuration. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertModel (DistilBERT model) + - contains `roberta`: RobertaModel (RoBERTa model) + - contains `bert`: TFBertModel (Bert model) + - contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model) + - contains `gpt2`: GPT2Model (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLModel (Transformer-XL model) + - contains `xlnet`: XLNetModel (XLNet model) + - contains `xlm`: XLMModel (XLM model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Params: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + model = TFAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = TFAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = TFAutoModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') + model = TFAutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + if 'distilbert' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'roberta' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'bert' in pretrained_model_name_or_path: + return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'openai-gpt' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'gpt2' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'transfo-xl' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'xlnet' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'xlm' in pretrained_model_name_or_path: + raise NotImplementedError + + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " + "'xlm', 'roberta'".format(pretrained_model_name_or_path)) + + +class TFAutoModelWithLMHead(object): + r""" + :class:`~pytorch_transformers.TFAutoModelWithLMHead` is a generic model class + that will be instantiated as one of the language modeling model classes of the library + when created with the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` + class method. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model) + - contains `roberta`: RobertaForMaskedLM (RoBERTa model) + - contains `bert`: TFBertForMaskedLM (Bert model) + - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model) + - contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model) + - contains `xlnet`: XLNetLMHeadModel (XLNet model) + - contains `xlm`: XLMWithLMHeadModel (XLM model) + + This class cannot be instantiated using `__init__()` (throws an error). + """ + def __init__(self): + raise EnvironmentError("TFAutoModelWithLMHead is designed to be instantiated " + "using the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the language modeling model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForMaskedLM (DistilBERT model) + - contains `roberta`: RobertaForMaskedLM (RoBERTa model) + - contains `bert`: TFBertForMaskedLM (Bert model) + - contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model) + - contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model) + - contains `xlnet`: XLNetLMHeadModel (XLNet model) + - contains `xlm`: XLMWithLMHeadModel (XLM model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Params: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = TFAutoModelWithLMHead.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') + model = TFAutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + if 'distilbert' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'roberta' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'bert' in pretrained_model_name_or_path: + return TFBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'openai-gpt' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'gpt2' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'transfo-xl' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'xlnet' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'xlm' in pretrained_model_name_or_path: + raise NotImplementedError + + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " + "'xlm', 'roberta'".format(pretrained_model_name_or_path)) + + +class TFAutoModelForSequenceClassification(object): + r""" + :class:`~pytorch_transformers.TFAutoModelForSequenceClassification` is a generic model class + that will be instantiated as one of the sequence classification model classes of the library + when created with the `TFAutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)` + class method. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model) + - contains `roberta`: RobertaForSequenceClassification (RoBERTa model) + - contains `bert`: TFBertForSequenceClassification (Bert model) + - contains `xlnet`: XLNetForSequenceClassification (XLNet model) + - contains `xlm`: XLMForSequenceClassification (XLM model) + + This class cannot be instantiated using `__init__()` (throws an error). + """ + def __init__(self): + raise EnvironmentError("TFAutoModelWithLMHead is designed to be instantiated " + "using the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the sequence classification model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model) + - contains `roberta`: RobertaForSequenceClassification (RoBERTa model) + - contains `bert`: TFBertForSequenceClassification (Bert model) + - contains `xlnet`: XLNetForSequenceClassification (XLNet model) + - contains `xlm`: XLMForSequenceClassification (XLM model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Params: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = TFAutoModelForSequenceClassification.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') + model = TFAutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + if 'distilbert' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'roberta' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'bert' in pretrained_model_name_or_path: + return TFBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'xlnet' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'xlm' in pretrained_model_name_or_path: + raise NotImplementedError + + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path)) + + +class TFAutoModelForQuestionAnswering(object): + r""" + :class:`~pytorch_transformers.TFAutoModelForQuestionAnswering` is a generic model class + that will be instantiated as one of the question answering model classes of the library + when created with the `TFAutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path)` + class method. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model) + - contains `bert`: TFBertForQuestionAnswering (Bert model) + - contains `xlnet`: XLNetForQuestionAnswering (XLNet model) + - contains `xlm`: XLMForQuestionAnswering (XLM model) + + This class cannot be instantiated using `__init__()` (throws an error). + """ + def __init__(self): + raise EnvironmentError("TFAutoModelWithLMHead is designed to be instantiated " + "using the `TFAutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` method.") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" Instantiates one of the question answering model classes of the library + from a pre-trained model configuration. + + The `from_pretrained()` method takes care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model) + - contains `bert`: TFBertForQuestionAnswering (Bert model) + - contains `xlnet`: XLNetForQuestionAnswering (XLNet model) + - contains `xlm`: XLMForQuestionAnswering (XLM model) + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) + To train the model, you should first set it back in training mode with `model.train()` + + Params: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. + - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + + model_args: (`optional`) Sequence of positional arguments: + All remaning positional arguments will be passed to the underlying model's ``__init__`` method + + config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`: + Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when: + + - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or + - the model was saved using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory. + - the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory. + + state_dict: (`optional`) dict: + an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file. + This option can be used if you want to create a model from a pretrained configuration but load your own weights. + In this case though, you should check if using :func:`~pytorch_transformers.PreTrainedModel.save_pretrained` and :func:`~pytorch_transformers.PreTrainedModel.from_pretrained` is not a simpler option. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + output_loading_info: (`optional`) boolean: + Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. + + kwargs: (`optional`) Remaining dictionary of keyword arguments: + Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: + + - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done) + - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function. + + Examples:: + + model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache. + model = TFAutoModelForQuestionAnswering.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` + model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading + assert model.config.output_attention == True + # Loading from a TF checkpoint file instead of a PyTorch model (slower) + config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json') + model = TFAutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) + + """ + if 'distilbert' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'bert' in pretrained_model_name_or_path: + return TFBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + elif 'xlnet' in pretrained_model_name_or_path: + raise NotImplementedError + elif 'xlm' in pretrained_model_name_or_path: + raise NotImplementedError + + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path)) diff --git a/pytorch_transformers/modeling_tf_utils.py b/pytorch_transformers/modeling_tf_utils.py index cb1588e399..ef2375c60a 100644 --- a/pytorch_transformers/modeling_tf_utils.py +++ b/pytorch_transformers/modeling_tf_utils.py @@ -170,9 +170,6 @@ class TFPreTrainedModel(tf.keras.Model): A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request. - output_loading_info: (`optional`) boolean: - Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages. - kwargs: (`optional`) Remaining dictionary of keyword arguments: Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded: @@ -195,7 +192,6 @@ class TFPreTrainedModel(tf.keras.Model): from_pt = kwargs.pop('from_pt', False) force_download = kwargs.pop('force_download', False) proxies = kwargs.pop('proxies', None) - output_loading_info = kwargs.pop('output_loading_info', False) # Load config if config is None: @@ -258,11 +254,4 @@ class TFPreTrainedModel(tf.keras.Model): ret = model(inputs, training=False) # Make sure restore ops are run - # if hasattr(model, 'tie_weights'): - # model.tie_weights() # TODO make sure word embedding weights are still tied - - if output_loading_info: - loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs} - return model, loading_info - return model diff --git a/pytorch_transformers/tests/modeling_tf_auto_test.py b/pytorch_transformers/tests/modeling_tf_auto_test.py new file mode 100644 index 0000000000..816d6c1b1a --- /dev/null +++ b/pytorch_transformers/tests/modeling_tf_auto_test.py @@ -0,0 +1,85 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import shutil +import pytest +import logging + +try: + from pytorch_transformers import (AutoConfig, BertConfig, + TFAutoModel, TFBertModel, + TFAutoModelWithLMHead, TFBertForMaskedLM, + TFAutoModelForSequenceClassification, TFBertForSequenceClassification, + TFAutoModelForQuestionAnswering, TFBertForQuestionAnswering) + from pytorch_transformers.modeling_tf_bert import TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP + + from .modeling_common_test import (CommonTestCases, ids_tensor) + from .configuration_common_test import ConfigTester +except ImportError: + pytestmark = pytest.mark.skip("Require TensorFlow") + + +class TFAutoModelTest(unittest.TestCase): + def test_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = TFAutoModel.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(model, TFBertModel) + + def test_lmhead_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = TFAutoModelWithLMHead.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(model, TFBertForMaskedLM) + + def test_sequence_classification_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = TFAutoModelForSequenceClassification.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(model, TFBertForSequenceClassification) + + def test_question_answering_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = TFAutoModelForQuestionAnswering.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(model, TFBertForQuestionAnswering) + + +if __name__ == "__main__": + unittest.main()