add CFG for .generate() (#24654)

This commit is contained in:
Guillaume "Vermeille" Sanchez
2023-08-06 21:15:24 +02:00
committed by GitHub
parent a6e6b1c622
commit d533465150
5 changed files with 235 additions and 4 deletions

View File

@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
],
)
@slow
def test_cfg_mixin(self):
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
input["input_ids"] = input["input_ids"].to(torch_device)
input["attention_mask"] = input["attention_mask"].to(torch_device)
outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
"The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
],
)
neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
neg["input_ids"] = neg["input_ids"].to(torch_device)
neg["attention_mask"] = neg["attention_mask"].to(torch_device)
outputs = model.generate(
**input,
max_new_tokens=32,
guidance_scale=1.5,
negative_prompt_ids=neg["input_ids"],
negative_prompt_attention_mask=neg["attention_mask"],
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertListEqual(
generated_text,
[
'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
],
)
@slow
def test_constrained_beam_search_example_translation_mixin(self):
# PT-only test: TF doesn't have constrained beam search