From 41e8291217ee0972be71b53bf4b43672ae861578 Mon Sep 17 00:00:00 2001 From: fgaim Date: Mon, 11 May 2020 19:10:00 +0200 Subject: [PATCH] Add ALBERT to the Tensorflow to Pytorch model conversion cli (#3933) * Add ALBERT to convert command of transformers-cli * Document ALBERT tf to pytorch model conversion --- docs/source/converting_tensorflow_models.rst | 22 +++++++++++++++++++- src/transformers/commands/convert.py | 16 +++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/source/converting_tensorflow_models.rst b/docs/source/converting_tensorflow_models.rst index b6b93571dd..4151f8cf5c 100644 --- a/docs/source/converting_tensorflow_models.rst +++ b/docs/source/converting_tensorflow_models.rst @@ -12,7 +12,7 @@ A command-line interface is provided to convert original Bert/GPT/GPT-2/Transfor BERT ^^^^ -You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google `_\ ) in a PyTorch save file by using the `convert_tf_checkpoint_to_pytorch.py `_ script. +You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google `_\ ) in a PyTorch save file by using the `convert_bert_original_tf_checkpoint_to_pytorch.py `_ script. This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py `_\ , `run_bert_classifier.py `_ and `run_bert_squad.py `_\ ). @@ -33,6 +33,26 @@ Here is an example of the conversion process for a pre-trained ``BERT-Base Uncas You can download Google's pre-trained models for the conversion `here `__. +ALBERT +^^^^^^ + +Convert TensorFlow model checkpoints of ALBERT to PyTorch using the `convert_albert_original_tf_checkpoint_to_pytorch.py `_ script. + +The CLI takes as input a TensorFlow checkpoint (three files starting with ``model.ckpt-best``\ ) and the accompanying configuration file (\ ``albert_config.json``\ ), then creates and saves a PyTorch model. To run this conversion you will need to have TensorFlow and PyTorch installed. + +Here is an example of the conversion process for the pre-trained ``ALBERT Base`` model: + +.. code-block:: shell + + export ALBERT_BASE_DIR=/path/to/albert/albert_base + + transformers-cli convert --model_type albert \ + --tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \ + --config $ALBERT_BASE_DIR/albert_config.json \ + --pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin + +You can download Google's pre-trained models for the conversion `here `__. + OpenAI GPT ^^^^^^^^^^ diff --git a/src/transformers/commands/convert.py b/src/transformers/commands/convert.py index a31ef53b62..96464e3f91 100644 --- a/src/transformers/commands/convert.py +++ b/src/transformers/commands/convert.py @@ -62,7 +62,21 @@ class ConvertCommand(BaseTransformersCLICommand): self._finetuning_task_name = finetuning_task_name def run(self): - if self._model_type == "bert": + if self._model_type == "albert": + try: + from transformers.convert_albert_original_tf_checkpoint_to_pytorch import ( + convert_tf_checkpoint_to_pytorch, + ) + except ImportError: + msg = ( + "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " + "In that case, it requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise ImportError(msg) + + convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) + elif self._model_type == "bert": try: from transformers.convert_bert_original_tf_checkpoint_to_pytorch import ( convert_tf_checkpoint_to_pytorch,