Add custom stop token ids for generation (#20727)

* Add StopIdStoppingCriteria

* add a working test for stop id criteria

* add to global scope

* add stop_ids to generate

* add pipeline test

* use tokenizer encode in test

* add test to generation utils

* reformat

* fixup

* make-fix-copies

* rename to stop_token_id

* use stop_tokens instead

* add to text to text generation

* make fixup

* make repo-consistency

* Add support for list of ints for eos_token_id inside generation/utils.py

* Instead of having if elses, cast the eos_token_id into a List[int]

* Add List[int] support for logits_process.py

* add List[int] for beam_search.py

* add List[int] for forced_eos_token_id

* revert stop token id stopping criteria changes

* make fixup

* fix tests

* add eos_token_id to generation/utils.py and added tests test_utils.py

* add eos_token_id type hints and fix for pad tokens

* add comments

* remove some prints and remove forced false test

* fix

* put back test_stop_sequence_stopping_criteria

* remove unused import and make fixup

* add a none check

* update docstring

* add more docstring for list ints

* make fixup
This commit is contained in:
Motoki Wu
2023-01-03 12:18:24 -08:00
committed by GitHub
parent cd918492c6
commit 45da7cec5a
5 changed files with 210 additions and 64 deletions

View File

@@ -17,7 +17,7 @@
import inspect
import unittest
from transformers import is_torch_available
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
from ..test_modeling_common import floats_tensor, ids_tensor
@@ -39,7 +39,6 @@ if is_torch_available():
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
VisionEncoderDecoderModel,
pipeline,
top_k_top_p_filtering,
)
from transformers.generation import (
@@ -91,8 +90,9 @@ class GenerationTesterMixin:
max_length = input_ids.shape[-1] + 3
if config.eos_token_id is not None and config.pad_token_id is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config.pad_token_id = config.eos_token_id
if isinstance(config.eos_token_id, int):
config.eos_token_id = [config.eos_token_id]
config.pad_token_id = config.eos_token_id[0]
# TransfoXL has no attention mask
if "transfoxl" in config.__class__.__name__.lower():
attention_mask = None
@@ -3025,3 +3025,100 @@ class GenerationIntegrationTests(unittest.TestCase):
# However, valid model_kwargs are accepted
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
model.generate(input_ids, **valid_model_kwargs)
def test_eos_token_id_int_and_list_greedy_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
}
expectation = 13
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_contrastive_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 1,
"penalty_alpha": 0.6,
"top_k": 4,
}
expectation = 17
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 225
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [225]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
generation_kwargs = {
"do_sample": True,
"num_beams": 1,
"top_p": 0.7,
"top_k": 10,
"temperature": 0.7,
}
expectation = 15
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 846
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [846]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
def test_eos_token_id_int_and_list_beam_search(self):
generation_kwargs = {
"do_sample": False,
"num_beams": 3,
}
expectation = 13
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
tokens = tokenizer(text, return_tensors="pt")
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
torch.manual_seed(0)
eos_token_id = 873
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))
torch.manual_seed(0)
eos_token_id = [873]
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
self.assertTrue(expectation == len(generated_tokens[0]))