add CFG for .generate() (#24654)
This commit is contained in:
committed by
GitHub
parent
a6e6b1c622
commit
d533465150
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
||||
|
||||
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
||||
|
||||
def test_classifier_free_guidance(self):
    """Compare `UnbatchedClassifierFreeGuidanceLogitsProcessor` with a hand-computed result.

    With a guidance scale of 1.5 the processed scores are expected to equal
    `log_softmax(uncond) + 1.5 * (log_softmax(cond) - log_softmax(uncond))`.
    The same expectation is checked three times: with the unconditional prompt
    and its attention mask both passed explicitly, with only the prompt passed,
    and with neither passed.
    """

    class Namespace(dict):
        pass

    logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
    logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])

    def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
        # Stand-in model: every call returns the fixed unconditional logits.
        result = Namespace()
        result.logits = logits_uncond
        result.past_key_values = None
        return result

    def lsm(x):
        return torch.nn.functional.log_softmax(x, dim=-1)

    # The reference value depends only on the two fixed logit tensors.
    uncond_log_probs = lsm(logits_uncond)
    expected = (uncond_log_probs + 1.5 * (lsm(logits_cond) - uncond_log_probs))[0, -1]

    # Number of optional constructor arguments passed explicitly:
    #   2 -> explicit unconditional prompt + attention mask
    #   1 -> explicit unconditional prompt
    #   0 -> all implicit
    for explicit_count in (2, 1, 0):
        input_ids = torch.LongTensor([[0]])
        optional_args = (input_ids, torch.ones_like(input_ids, dtype=torch.long))[:explicit_count]
        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, *optional_args)
        actual = cfg(input_ids, logits_cond)[0, -1]

        for position in range(3):
            self.assertAlmostEqual(actual[position].item(), expected[position].item())
|
||||
|
||||
@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
],
|
||||
)
|
||||
|
||||
@slow
def test_cfg_mixin(self):
    """Check `generate()` with `guidance_scale=1.5` on GPT-2 against pinned outputs.

    Two generations of 32 new tokens are compared with fixed reference strings:
    one with only the prompt, and one that additionally passes a negative prompt
    through `negative_prompt_ids` / `negative_prompt_attention_mask`.
    """
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    def encode(text):
        # Tokenize one prompt and move both tensors to the test device.
        encoded = tokenizer([text], return_tensors="pt", return_attention_mask=True)
        encoded["input_ids"] = encoded["input_ids"].to(torch_device)
        encoded["attention_mask"] = encoded["attention_mask"].to(torch_device)
        return encoded

    # Named `prompt` rather than `input` so the builtin is not shadowed.
    prompt = encode("The dragon flew over Paris,")

    # Guidance without an explicit negative prompt.
    outputs = model.generate(**prompt, max_new_tokens=32, guidance_scale=1.5)
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    self.assertListEqual(
        generated_text,
        [
            "The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
            'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
        ],
    )

    # Guidance with an explicit negative prompt.
    neg = encode("France,")
    outputs = model.generate(
        **prompt,
        max_new_tokens=32,
        guidance_scale=1.5,
        negative_prompt_ids=neg["input_ids"],
        negative_prompt_attention_mask=neg["attention_mask"],
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    self.assertListEqual(
        generated_text,
        [
            'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
            'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
        ],
    )
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_translation_mixin(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
|
||||
Reference in New Issue
Block a user