add CFG for .generate() (#24654)

This commit is contained in:
Guillaume "Vermeille" Sanchez
2023-08-06 21:15:24 +02:00
committed by GitHub
parent a6e6b1c622
commit d533465150
5 changed files with 235 additions and 4 deletions

View File

@@ -51,6 +51,7 @@ if is_torch_available():
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
def test_classifier_free_guidance(self):
class Namespace(dict):
pass
logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])
def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
out = Namespace()
out.logits = logits_uncond
out.past_key_values = None
return out
def lsm(x):
return torch.nn.functional.log_softmax(x, dim=-1)
# explicit unconditional prompt + attention mask
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
)
out = cfg(input_ids, logits_cond)[0, -1]
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
# explicit unconditional prompt
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
out = cfg(input_ids, logits_cond)[0, -1]
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())
# all implicit
input_ids = torch.LongTensor([[0]])
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
out = cfg(input_ids, logits_cond)[0, -1]
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
self.assertAlmostEqual(out[0].item(), res[0].item())
self.assertAlmostEqual(out[1].item(), res[1].item())
self.assertAlmostEqual(out[2].item(), res[2].item())