This commit is contained in:
Lysandre Debut
2021-11-26 04:35:08 -05:00
committed by GitHub
parent c15f4f203f
commit 2318bf77eb

View File

@@ -346,35 +346,43 @@ Once your model is fine-tuned, you can save it with its tokenizer in the followi
.. code-block::
>>> ## PYTORCH CODE
>>> pt_save_directory = './pt_save_pretrained'
>>> tokenizer.save_pretrained(pt_save_directory)
>>> pt_model.save_pretrained(pt_save_directory)
.. code-block::
>>> ## TENSORFLOW CODE
>>> tf_save_directory = './tf_save_pretrained'
>>> tokenizer.save_pretrained(tf_save_directory)
>>> tf_model.save_pretrained(tf_save_directory)
You can then load this model back using the :func:`~transformers.AutoModel.from_pretrained` method by passing the
directory name instead of the model name. One cool feature of 🤗 Transformers is that you can easily switch between
PyTorch and TensorFlow: any model saved as before can be loaded back either in PyTorch or TensorFlow. If you are
loading a saved PyTorch model in a TensorFlow model, use :func:`~transformers.TFAutoModel.from_pretrained` like this:
PyTorch and TensorFlow: any model saved as before can be loaded back either in PyTorch or TensorFlow.
If you would like to load your saved model in the other framework, first make sure it is installed:
.. code-block:: bash
## PYTORCH CODE
pip install tensorflow
## TENSORFLOW CODE
pip install torch
Then, use the corresponding Auto class to load it like this:
.. code-block::
## PYTORCH CODE
>>> from transformers import TFAutoModel
>>> tokenizer = AutoTokenizer.from_pretrained(pt_save_directory)
>>> tf_model = TFAutoModel.from_pretrained(pt_save_directory, from_pt=True)
and if you are loading a saved TensorFlow model in a PyTorch model, you should use the following code:
.. code-block::
## TENSORFLOW CODE
>>> from transformers import AutoModel
>>> tokenizer = AutoTokenizer.from_pretrained(tf_save_directory)
>>> pt_model = AutoModel.from_pretrained(tf_save_directory, from_tf=True)
Lastly, you can also ask the model to return all hidden states and all attention weights if you need them: