Add custom stop token ids for generation (#20727)
* Add StopIdStoppingCriteria * add a working test for stop id criteria * add to global scope * add stop_ids to generate * add pipeline test * use tokenizer encode in test * add test to generation utils * reformat * fixup * make-fix-copies * rename to stop_token_id * use stop_tokens instead * add to text to text generation * make fixup * make repo-consistency * Add support for list of ints for eos_token_id inside generation/utils.py * Instead of having if elses, cast the eos_token_id into a List[int] * Add List[int] support for logits_process.py * add List[int] for beam_search.py * add List[int] for forced_eos_token_id * revert stop token id stopping criteria changes * make fixup * fix tests * add eos_token_id to generation/utils.py and added tests test_utils.py * add eos_token_id type hints and fix for pad tokens * add comments * remove some prints and remove forced false test * fix * put back test_stop_sequence_stopping_criteria * remove unused import and make fixup * add a none check * update docstring * add more docstring for list ints * make fixup
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers import is_torch_available, pipeline
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ..test_modeling_common import floats_tensor, ids_tensor
|
||||
@@ -39,7 +39,6 @@ if is_torch_available():
|
||||
SpeechEncoderDecoderModel,
|
||||
T5ForConditionalGeneration,
|
||||
VisionEncoderDecoderModel,
|
||||
pipeline,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.generation import (
|
||||
@@ -91,8 +90,9 @@ class GenerationTesterMixin:
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
# TransfoXL has no attention mask
|
||||
if "transfoxl" in config.__class__.__name__.lower():
|
||||
attention_mask = None
|
||||
@@ -3025,3 +3025,100 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# However, valid model_kwargs are accepted
|
||||
valid_model_kwargs = {"attention_mask": torch.zeros_like(input_ids)}
|
||||
model.generate(input_ids, **valid_model_kwargs)
|
||||
|
||||
def test_eos_token_id_int_and_list_greedy_search(self):
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 1,
|
||||
}
|
||||
expectation = 13
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="pt")
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = 873
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [873]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_contrastive_search(self):
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 1,
|
||||
"penalty_alpha": 0.6,
|
||||
"top_k": 4,
|
||||
}
|
||||
expectation = 17
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="pt")
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = 225
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [225]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||
generation_kwargs = {
|
||||
"do_sample": True,
|
||||
"num_beams": 1,
|
||||
"top_p": 0.7,
|
||||
"top_k": 10,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
expectation = 15
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="pt")
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = 846
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [846]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
def test_eos_token_id_int_and_list_beam_search(self):
|
||||
generation_kwargs = {
|
||||
"do_sample": False,
|
||||
"num_beams": 3,
|
||||
}
|
||||
expectation = 13
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
text = """Hello, my dog is cute and"""
|
||||
tokens = tokenizer(text, return_tensors="pt")
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = 873
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
torch.manual_seed(0)
|
||||
eos_token_id = [873]
|
||||
generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
|
||||
self.assertTrue(expectation == len(generated_tokens[0]))
|
||||
|
||||
Reference in New Issue
Block a user