Generation: fix handling of special tokens (#31254)

* fix special tokens in generatioon

* fix test

* add warning

* fix the check

* warn once

* fix
This commit is contained in:
Raushan Turganbay
2024-06-06 15:21:32 +05:00
committed by GitHub
parent 7729b77478
commit 5fabd1e83b
2 changed files with 29 additions and 30 deletions

View File

@@ -161,6 +161,7 @@ class GenerationIntegrationTestsMixin:
tokenizer.pad_token = tokenizer.eos_token
model = model_cls.from_pretrained("distilbert/distilgpt2")
model.generation_config.eos_token_id = None
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
@@ -170,7 +171,6 @@ class GenerationIntegrationTestsMixin:
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
@@ -197,6 +197,7 @@ class GenerationIntegrationTestsMixin:
tokenizer.pad_token = tokenizer.eos_token
model = model_cls.from_pretrained("distilbert/distilgpt2")
model.generation_config.eos_token_id = None
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
@@ -206,7 +207,6 @@ class GenerationIntegrationTestsMixin:
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)