Add t5 convert to transformers-cli (#9654)

* Update run_mlm.py

* add t5 model to transformers-cli convert

* update rum_mlm.py same as master

* update converting model docs

* update converting model docs

* Update convert.py

* Trigger notification

* update import sorted

* fix typo t5
This commit is contained in:
acul3
2021-01-20 21:34:27 +07:00
committed by GitHub
parent 7251a4736d
commit 8940c7662d
2 changed files with 23 additions and 1 deletions

View File

@@ -110,6 +110,13 @@ class ConvertCommand(BaseTransformersCLICommand):
except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "t5":
try:
from ..models.t5.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt":
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
@@ -168,5 +175,5 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_lxmert_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
else:
raise ValueError(
"--model_type should be selected in the list [bert, gpt, gpt2, transfo_xl, xlnet, xlm, lxmert]"
"--model_type should be selected in the list [bert, gpt, gpt2, t5, transfo_xl, xlnet, xlm, lxmert]"
)