Adopt framework-specific blocks for content (#16342)

*  refactor code samples with framework-specific blocks

*  update training.mdx

* 🖍 apply feedback
This commit is contained in:
Steven Liu
2022-03-22 14:14:58 -07:00
committed by GitHub
parent 62cbd8423b
commit 7732148124
13 changed files with 169 additions and 133 deletions

View File

@@ -110,8 +110,10 @@ Use [`DataCollatorForSeq2Seq`] to create a batch of examples. It will also *dyna
</tf>
</frameworkcontent>
## Fine-tune with Trainer
## Train
<frameworkcontent>
<pt>
Load T5 with [`AutoModelForSeq2SeqLM`]:
```py
@@ -156,18 +158,9 @@ At this point, only three steps remain:
>>> trainer.train()
```
## Fine-tune with TensorFlow
To fine-tune a model in TensorFlow is just as easy, with only a few differences.
<Tip>
If you aren't familiar with fine-tuning a model with Keras, take a look at the basic tutorial [here](../training#finetune-with-keras)!
</Tip>
Convert your datasets to the `tf.data.Dataset` format with [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Specify inputs and labels in `columns`, whether to shuffle the dataset order, batch size, and the data collator:
</pt>
<tf>
To fine-tune a model in TensorFlow, start by converting your datasets to the `tf.data.Dataset` format with [`to_tf_dataset`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.to_tf_dataset). Specify inputs and labels in `columns`, whether to shuffle the dataset order, batch size, and the data collator:
```py
>>> tf_train_set = tokenized_billsum["train"].to_tf_dataset(
@@ -185,6 +178,12 @@ Convert your datasets to the `tf.data.Dataset` format with [`to_tf_dataset`](htt
... )
```
<Tip>
If you aren't familiar with fine-tuning a model with Keras, take a look at the basic tutorial [here](training#finetune-with-keras)!
</Tip>
Set up an optimizer function, learning rate schedule, and some training hyperparameters:
```py
@@ -212,6 +211,8 @@ Call [`fit`](https://keras.io/api/models/model_training_apis/#fit-method) to fin
```py
>>> model.fit(x=tf_train_set, validation_data=tf_test_set, epochs=3)
```
</tf>
</frameworkcontent>
<Tip>