Adopt framework-specific blocks for content (#16342)
* ✨ refactor code samples with framework-specific blocks * ✨ update training.mdx * 🖍 apply feedback
This commit is contained in:
@@ -63,8 +63,10 @@ If you like, you can create a smaller subset of the full dataset to fine-tune on
|
||||
|
||||
<a id='trainer'></a>
|
||||
|
||||
## Fine-tune with `Trainer`
|
||||
## Train
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
<Youtube id="nvBXf7s7vTI"/>
|
||||
|
||||
🤗 Transformers provides a [`Trainer`] class optimized for training 🤗 Transformers models, making it easier to start training without manually writing your own training loop. The [`Trainer`] API supports a wide range of training options and features such as logging, gradient accumulation, and mixed precision.
|
||||
@@ -143,14 +145,13 @@ Then fine-tune your model by calling [`~transformers.Trainer.train`]:
|
||||
```py
|
||||
>>> trainer.train()
|
||||
```
|
||||
|
||||
</pt>
|
||||
<tf>
|
||||
<a id='keras'></a>
|
||||
|
||||
## Fine-tune with Keras
|
||||
|
||||
<Youtube id="rnTGBy2ax1c"/>
|
||||
|
||||
🤗 Transformers models also supports training in TensorFlow with the Keras API. You only need to make a few changes before you can fine-tune.
|
||||
🤗 Transformers models also supports training in TensorFlow with the Keras API.
|
||||
|
||||
### Convert dataset to TensorFlow format
|
||||
|
||||
@@ -210,11 +211,15 @@ Then compile and fine-tune your model with [`fit`](https://keras.io/api/models/m
|
||||
|
||||
>>> model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)
|
||||
```
|
||||
</tf>
|
||||
</frameworkcontent>
|
||||
|
||||
<a id='pytorch_native'></a>
|
||||
|
||||
## Fine-tune in native PyTorch
|
||||
## Train in native PyTorch
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
<Youtube id="Dh9CL8fyG80"/>
|
||||
|
||||
[`Trainer`] takes care of the training loop and allows you to fine-tune a model in a single line of code. For users who prefer to write their own training loop, you can also fine-tune a 🤗 Transformers model in native PyTorch.
|
||||
@@ -354,6 +359,8 @@ Just like how you need to add an evaluation function to [`Trainer`], you need to
|
||||
|
||||
>>> metric.compute()
|
||||
```
|
||||
</pt>
|
||||
</frameworkcontent>
|
||||
|
||||
<a id='additional-resources'></a>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user