[trainer] handle case where EOS token is None in generation_config (#40127)

* handle case where EOS token is None in gen config

* update eli5 dataset
This commit is contained in:
Joao Gante
2025-08-13 15:57:17 +01:00
committed by GitHub
parent 8ef5cd6579
commit 11537c3e0c
2 changed files with 9 additions and 6 deletions

View File

@@ -29,7 +29,7 @@ the left. This means the model cannot see future tokens. GPT-2 is an example of
This guide will show you how to:
1. Finetune [DistilGPT2](https://huggingface.co/distilbert/distilgpt2) on the [r/askscience](https://www.reddit.com/r/askscience/) subset of the [ELI5](https://huggingface.co/datasets/eli5) dataset.
1. Finetune [DistilGPT2](https://huggingface.co/distilbert/distilgpt2) on the [r/askscience](https://www.reddit.com/r/askscience/) subset of the [ELI5](https://huggingface.co/datasets/dany0407/eli5_category) dataset.
2. Use your finetuned model for inference.
<Tip>
@@ -54,12 +54,12 @@ We encourage you to log in to your Hugging Face account so you can upload and sh
## Load ELI5 dataset
Start by loading the first 5000 examples from the [ELI5-Category](https://huggingface.co/datasets/eli5_category) dataset with the 🤗 Datasets library. This'll give you a chance to experiment and make sure everything works before spending more time training on the full dataset.
Start by loading the first 5000 examples from the [ELI5-Category](https://huggingface.co/datasets/dany0407/eli5_category) dataset with the 🤗 Datasets library. This'll give you a chance to experiment and make sure everything works before spending more time training on the full dataset.
```py
>>> from datasets import load_dataset
>>> eli5 = load_dataset("eli5_category", split="train[:5000]")
>>> eli5 = load_dataset("dany0407/eli5_category", split="train[:5000]")
```
Split the dataset's `train` split into a train and test set with the [`~datasets.Dataset.train_test_split`] method:

View File

@@ -943,7 +943,9 @@ class Trainer:
# The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
# EOS tokens defined here will halt generation.
if model_has_generation_config:
all_eos_tokens = [tokenizer.eos_token_id] + list(self.model.generation_config.eos_token_id)
all_eos_tokens = [tokenizer.eos_token_id]
if self.model.generation_config.eos_token_id is not None:
all_eos_tokens += list(self.model.generation_config.eos_token_id)
self.model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
# 2 - Align BOS
@@ -971,8 +973,9 @@ class Trainer:
# 4 - Warn users about the changes
if len(updated_tokens) > 0:
logger.warning(
"The tokenizer has new special tokens that are also defined in the model configs. The model "
f"configs were aligned accordingly. Updated tokens: {updated_tokens}"
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. "
"The model config and generation config were aligned accordingly, being updated with the tokenizer's "
f"values. Updated tokens: {updated_tokens}."
)
def _set_signature_columns_if_needed(self):