Adopt framework-specific blocks for content (#16342)

*  refactor code samples with framework-specific blocks

*  update training.mdx

* 🖍 apply feedback
This commit is contained in:
Steven Liu
2022-03-22 14:14:58 -07:00
committed by GitHub
parent 62cbd8423b
commit 7732148124
13 changed files with 169 additions and 133 deletions

View File

@@ -75,25 +75,29 @@ To ensure your model can be used by someone working with a different framework,
Converting a checkpoint for another framework is easy. Make sure you have PyTorch and TensorFlow installed (see [here](installation) for installation instructions), and then find the specific model for your task in the other framework.
For example, suppose you trained DistilBert for sequence classification in PyTorch and want to convert it to it's TensorFlow equivalent. Load the TensorFlow equivalent of your model for your task, and specify `from_pt=True` so 🤗 Transformers will convert the PyTorch checkpoint to a TensorFlow checkpoint:
```py
>>> tf_model = TFDistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_pt=True)
```
Then save your new TensorFlow model with it's new checkpoint:
```py
>>> tf_model.save_pretrained("path/to/awesome-name-you-picked")
```
Similarly, specify `from_tf=True` to convert a checkpoint from TensorFlow to PyTorch:
<frameworkcontent>
<pt>
Specify `from_tf=True` to convert a checkpoint from TensorFlow to PyTorch:
```py
>>> pt_model = DistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_tf=True)
>>> pt_model.save_pretrained("path/to/awesome-name-you-picked")
```
</pt>
<tf>
Specify `from_pt=True` to convert a checkpoint from PyTorch to TensorFlow:
```py
>>> tf_model = TFDistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_pt=True)
```
Then you can save your new TensorFlow model with it's new checkpoint:
```py
>>> tf_model.save_pretrained("path/to/awesome-name-you-picked")
```
</tf>
<jax>
If a model is available in Flax, you can also convert a checkpoint from PyTorch to Flax:
```py
@@ -101,9 +105,13 @@ If a model is available in Flax, you can also convert a checkpoint from PyTorch
... "path/to/awesome-name-you-picked", from_pt=True
... )
```
</jax>
</frameworkcontent>
## Push a model with `Trainer`
## Push a model during training
<frameworkcontent>
<pt>
<Youtube id="Z1-XMy-GNLQ"/>
Sharing a model to the Hub is as simple as adding an extra parameter or callback. Remember from the [fine-tuning tutorial](training), the [`TrainingArguments`] class is where you specify hyperparameters and additional training options. One of these training options includes the ability to push a model directly to the Hub. Set `push_to_hub=True` in your [`TrainingArguments`]:
@@ -129,10 +137,9 @@ After you fine-tune your model, call [`~transformers.Trainer.push_to_hub`] on [`
```py
>>> trainer.push_to_hub()
```
## Push a model with `PushToHubCallback`
TensorFlow users can enable the same functionality with [`PushToHubCallback`]. In the [`PushToHubCallback`] function, add:
</pt>
<tf>
Share a model to the Hub with [`PushToHubCallback`]. In the [`PushToHubCallback`] function, add:
- An output directory for your model.
- A tokenizer.
@@ -151,6 +158,8 @@ Add the callback to [`fit`](https://keras.io/api/models/model_training_apis/), a
```py
>>> model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3, callbacks=push_to_hub_callback)
```
</tf>
</frameworkcontent>
## Use the `push_to_hub` function