Add ALBERT to the Tensorflow to Pytorch model conversion cli (#3933)
* Add ALBERT to convert command of transformers-cli * Document ALBERT tf to pytorch model conversion
This commit is contained in:
@@ -12,7 +12,7 @@ A command-line interface is provided to convert original Bert/GPT/GPT-2/Transfor
|
|||||||
BERT
|
BERT
|
||||||
^^^^
|
^^^^
|
||||||
|
|
||||||
You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google <https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the `convert_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/transformers/convert_tf_checkpoint_to_pytorch.py>`_ script.
|
You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google <https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the `convert_bert_original_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_ script.
|
||||||
|
|
||||||
This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_extract_features.py>`_\ , `run_bert_classifier.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_classifier.py>`_ and `run_bert_squad.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_squad.py>`_\ ).
|
This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_extract_features.py>`_\ , `run_bert_classifier.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_classifier.py>`_ and `run_bert_squad.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_squad.py>`_\ ).
|
||||||
|
|
||||||
@@ -33,6 +33,26 @@ Here is an example of the conversion process for a pre-trained ``BERT-Base Uncas
|
|||||||
|
|
||||||
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/bert#pre-trained-models>`__.
|
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/bert#pre-trained-models>`__.
|
||||||
|
|
||||||
|
ALBERT
|
||||||
|
^^^^^^
|
||||||
|
|
||||||
|
Convert TensorFlow model checkpoints of ALBERT to PyTorch using the `convert_albert_original_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_ script.
|
||||||
|
|
||||||
|
The CLI takes as input a TensorFlow checkpoint (three files starting with ``model.ckpt-best``\ ) and the accompanying configuration file (\ ``albert_config.json``\ ), then creates and saves a PyTorch model. To run this conversion you will need to have TensorFlow and PyTorch installed.
|
||||||
|
|
||||||
|
Here is an example of the conversion process for the pre-trained ``ALBERT Base`` model:
|
||||||
|
|
||||||
|
.. code-block:: shell
|
||||||
|
|
||||||
|
export ALBERT_BASE_DIR=/path/to/albert/albert_base
|
||||||
|
|
||||||
|
transformers-cli convert --model_type albert \
|
||||||
|
--tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \
|
||||||
|
--config $ALBERT_BASE_DIR/albert_config.json \
|
||||||
|
--pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin
|
||||||
|
|
||||||
|
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/albert#pre-trained-models>`__.
|
||||||
|
|
||||||
OpenAI GPT
|
OpenAI GPT
|
||||||
^^^^^^^^^^
|
^^^^^^^^^^
|
||||||
|
|
||||||
|
|||||||
@@ -62,7 +62,21 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
self._finetuning_task_name = finetuning_task_name
|
self._finetuning_task_name = finetuning_task_name
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
if self._model_type == "bert":
|
if self._model_type == "albert":
|
||||||
|
try:
|
||||||
|
from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||||
|
convert_tf_checkpoint_to_pytorch,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
msg = (
|
||||||
|
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
|
"https://www.tensorflow.org/install/ for installation instructions."
|
||||||
|
)
|
||||||
|
raise ImportError(msg)
|
||||||
|
|
||||||
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
|
elif self._model_type == "bert":
|
||||||
try:
|
try:
|
||||||
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
|
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_tf_checkpoint_to_pytorch,
|
convert_tf_checkpoint_to_pytorch,
|
||||||
|
|||||||
Reference in New Issue
Block a user