add CFG for .generate() (#24654)
This commit is contained in:
committed by
GitHub
parent
a6e6b1c622
commit
d533465150
@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_cfg_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
|
||||
input["input_ids"] = input["input_ids"].to(torch_device)
|
||||
input["attention_mask"] = input["attention_mask"].to(torch_device)
|
||||
|
||||
outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
|
||||
'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
|
||||
],
|
||||
)
|
||||
|
||||
neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
|
||||
neg["input_ids"] = neg["input_ids"].to(torch_device)
|
||||
neg["attention_mask"] = neg["attention_mask"].to(torch_device)
|
||||
outputs = model.generate(
|
||||
**input,
|
||||
max_new_tokens=32,
|
||||
guidance_scale=1.5,
|
||||
negative_prompt_ids=neg["input_ids"],
|
||||
negative_prompt_attention_mask=neg["attention_mask"],
|
||||
)
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
|
||||
'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_translation_mixin(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
|
||||
Reference in New Issue
Block a user