add CFG for .generate() (#24654)
This commit is contained in:
committed by
GitHub
parent
a6e6b1c622
commit
d533465150
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
||||
|
||||
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
||||
|
||||
def test_classifier_free_guidance(self):
|
||||
class Namespace(dict):
|
||||
pass
|
||||
|
||||
logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
|
||||
logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])
|
||||
|
||||
def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
|
||||
out = Namespace()
|
||||
out.logits = logits_uncond
|
||||
out.past_key_values = None
|
||||
return out
|
||||
|
||||
def lsm(x):
|
||||
return torch.nn.functional.log_softmax(x, dim=-1)
|
||||
|
||||
# explicit unconditional prompt + attention mask
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
|
||||
1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
|
||||
)
|
||||
out = cfg(input_ids, logits_cond)[0, -1]
|
||||
|
||||
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
|
||||
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
# explicit unconditional prompt
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
|
||||
out = cfg(input_ids, logits_cond)[0, -1]
|
||||
|
||||
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
|
||||
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
# all implicit
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
|
||||
out = cfg(input_ids, logits_cond)[0, -1]
|
||||
|
||||
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
|
||||
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
Reference in New Issue
Block a user