Add t5 convert to transformers-cli (#9654)
* Update run_mlm.py * add t5 model to transformers-cli convert * update rum_mlm.py same as master * update converting model docs * update converting model docs * Update convert.py * Trigger notification * update import sorted * fix typo t5
This commit is contained in:
@@ -168,3 +168,18 @@ Here is an example of the conversion process for a pre-trained XLM model:
|
|||||||
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT
|
--pytorch_dump_output $PYTORCH_DUMP_OUTPUT
|
||||||
[--config XML_CONFIG] \
|
[--config XML_CONFIG] \
|
||||||
[--finetuning_task_name XML_FINETUNED_TASK]
|
[--finetuning_task_name XML_FINETUNED_TASK]
|
||||||
|
|
||||||
|
|
||||||
|
T5
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
Here is an example of the conversion process for a pre-trained T5 model:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
export T5=/path/to/t5/uncased_L-12_H-768_A-12
|
||||||
|
|
||||||
|
transformers-cli convert --model_type t5 \
|
||||||
|
--tf_checkpoint $T5/t5_model.ckpt \
|
||||||
|
--config $T5/t5_config.json \
|
||||||
|
--pytorch_dump_output $T5/pytorch_model.bin
|
||||||
|
|||||||
@@ -110,6 +110,13 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(IMPORT_ERROR_MESSAGE)
|
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||||
|
|
||||||
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
|
elif self._model_type == "t5":
|
||||||
|
try:
|
||||||
|
from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(IMPORT_ERROR_MESSAGE)
|
||||||
|
|
||||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "gpt":
|
elif self._model_type == "gpt":
|
||||||
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||||
@@ -168,5 +175,5 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm, lxmert]"
|
"--model_type should be selected in the list [bert, gpt, gpt2, t5, transfo_xl, xlnet, xlm, lxmert]"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user