From 8940c7662de51b78d9585e870cef1a7944830cc5 Mon Sep 17 00:00:00 2001 From: acul3 <56231298+acul3@users.noreply.github.com> Date: Wed, 20 Jan 2021 21:34:27 +0700 Subject: [PATCH] Add t5 convert to transformers-cli (#9654) * Update run_mlm.py * add t5 model to transformers-cli convert * update rum_mlm.py same as master * update converting model docs * update converting model docs * Update convert.py * Trigger notification * update import sorted * fix typo t5 --- docs/source/converting_tensorflow_models.rst | 15 +++++++++++++++ src/transformers/commands/convert.py | 9 ++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/docs/source/converting_tensorflow_models.rst b/docs/source/converting_tensorflow_models.rst index 9cc1333fde..cfc057bc82 100644 --- a/docs/source/converting_tensorflow_models.rst +++ b/docs/source/converting_tensorflow_models.rst @@ -168,3 +168,18 @@ Here is an example of the conversion process for a pre-trained XLM model: --pytorch_dump_output $PYTORCH_DUMP_OUTPUT [--config XML_CONFIG] \ [--finetuning_task_name XML_FINETUNED_TASK] + + +T5 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Here is an example of the conversion process for a pre-trained T5 model: + +.. code-block:: shell + + export T5=/path/to/t5/uncased_L-12_H-768_A-12 + + transformers-cli convert --model_type t5 \ + --tf_checkpoint $T5/t5_model.ckpt \ + --config $T5/t5_config.json \ + --pytorch_dump_output $T5/pytorch_model.bin diff --git a/src/transformers/commands/convert.py b/src/transformers/commands/convert.py index 30767f26f9..6867cf6c01 100644 --- a/src/transformers/commands/convert.py +++ b/src/transformers/commands/convert.py @@ -110,6 +110,13 @@ class ConvertCommand(BaseTransformersCLICommand): except ImportError: raise ImportError(IMPORT_ERROR_MESSAGE) + convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) + elif self._model_type == "t5": + try: + from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch + except ImportError: + raise ImportError(IMPORT_ERROR_MESSAGE) + convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) elif self._model_type == "gpt": from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import ( @@ -168,5 +175,5 @@ class ConvertCommand(BaseTransformersCLICommand): convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output) else: raise ValueError( - "--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm, lxmert]" + "--model_type should be selected in the list [bert, gpt, gpt2, t5, transfo_xl, xlnet, xlm, lxmert]" )