add CFG for .generate() (#24654)
This commit is contained in:
committed by
GitHub
parent
a6e6b1c622
commit
d533465150
@@ -65,6 +65,7 @@ else:
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"LogitNormalization",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
]
|
||||
_import_structure["stopping_criteria"] = [
|
||||
"MaxNewTokensCriteria",
|
||||
@@ -188,6 +189,7 @@ if TYPE_CHECKING:
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -1334,3 +1334,119 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
|
||||
scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
||||
r"""Logits processor for Classifier-Free Guidance (CFG). The processors
|
||||
computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
|
||||
parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
|
||||
the `unconditional_ids` branch.
|
||||
|
||||
See [the paper](https://arxiv.org/abs/2306.17806) for more information.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`):
|
||||
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
|
||||
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
|
||||
prompt, usually at the expense of poorer quality.
|
||||
unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
|
||||
the last token of the prompt.
|
||||
unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, **optional**):
|
||||
Attention mask for unconditional_ids.
|
||||
model (`PreTrainedModel`):
|
||||
The model computing the unconditional scores. Supposedly the same as the one computing the conditional
|
||||
scores. Both models must use the same tokenizer.
|
||||
smooth_factor (`float`, **optional**):
|
||||
The interpolation weight for CFG Rescale. 1 means no rescaling, 0 reduces to the conditional scores without
|
||||
CFG. Turn it lower if the output degenerates.
|
||||
use_cache (`bool`, **optional**):
|
||||
Whether to cache key/values during the negative prompt forward pass.
|
||||
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
>>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
|
||||
>>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
|
||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
The dragon flew over Paris, France, landing in Lyon, a city of a few million. Dragon-flying was a new form of
|
||||
transport, and the dragon was the first in Europe.
|
||||
|
||||
>>> # with a negative prompt
|
||||
>>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
|
||||
>>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
|
||||
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
|
||||
The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
|
||||
people and injuring more than 350.
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float,
|
||||
model,
|
||||
unconditional_ids: Optional[torch.LongTensor] = None,
|
||||
unconditional_attention_mask: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = True,
|
||||
):
|
||||
self.guidance_scale = guidance_scale
|
||||
self.model = model
|
||||
self.unconditional_context = {
|
||||
"input_ids": unconditional_ids,
|
||||
"attention_mask": unconditional_attention_mask,
|
||||
"use_cache": use_cache,
|
||||
"past_key_values": None,
|
||||
"first_pass": True,
|
||||
}
|
||||
|
||||
def get_unconditional_logits(self, input_ids):
|
||||
if self.unconditional_context["first_pass"]:
|
||||
if self.unconditional_context["input_ids"] is None:
|
||||
self.unconditional_context["input_ids"] = input_ids[:, -1:]
|
||||
if self.unconditional_context["attention_mask"] is None:
|
||||
self.unconditional_context["attention_mask"] = torch.ones_like(
|
||||
self.unconditional_context["input_ids"], dtype=torch.long
|
||||
)
|
||||
input_ids = self.unconditional_context["input_ids"]
|
||||
attention_mask = self.unconditional_context["attention_mask"]
|
||||
self.unconditional_context["first_pass"] = False
|
||||
else:
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
self.unconditional_context["attention_mask"],
|
||||
torch.ones_like(input_ids[:, -1:], dtype=torch.long),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
if not self.unconditional_context["use_cache"]:
|
||||
input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
|
||||
else:
|
||||
input_ids = input_ids[:, -1:]
|
||||
self.unconditional_context["input_ids"] = input_ids
|
||||
self.unconditional_context["attention_mask"] = attention_mask
|
||||
|
||||
out = self.model(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=self.unconditional_context["use_cache"],
|
||||
past_key_values=self.unconditional_context["past_key_values"],
|
||||
)
|
||||
self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
|
||||
|
||||
return out.logits
|
||||
|
||||
def __call__(self, input_ids, scores):
|
||||
scores = torch.nn.functional.log_softmax(scores, dim=-1)
|
||||
if self.guidance_scale == 1:
|
||||
return scores
|
||||
|
||||
logits = self.get_unconditional_logits(input_ids)
|
||||
|
||||
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
|
||||
out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
|
||||
return out
|
||||
|
||||
@@ -38,7 +38,6 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .configuration_utils import GenerationConfig
|
||||
from .logits_process import (
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
@@ -64,6 +63,7 @@ from .logits_process import (
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
@@ -893,6 +893,9 @@ class GenerationMixin:
|
||||
encoder_input_ids: torch.LongTensor,
|
||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||
logits_processor: Optional[LogitsProcessorList],
|
||||
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> LogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
|
||||
@@ -901,6 +904,16 @@ class GenerationMixin:
|
||||
# instantiate processors list
|
||||
processors = LogitsProcessorList()
|
||||
|
||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
||||
processors.append(
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor(
|
||||
generation_config.guidance_scale,
|
||||
self,
|
||||
unconditional_ids=negative_prompt_ids,
|
||||
unconditional_attention_mask=negative_prompt_attention_mask,
|
||||
use_cache=model_kwargs["use_cache"],
|
||||
)
|
||||
)
|
||||
if generation_config.sequence_bias is not None:
|
||||
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
|
||||
|
||||
@@ -998,8 +1011,6 @@ class GenerationMixin:
|
||||
)
|
||||
if generation_config.forced_decoder_ids is not None:
|
||||
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
||||
processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
|
||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||
# `LogitNormalization` should always be the last logit processor, when present
|
||||
if generation_config.renormalize_logits is True:
|
||||
@@ -1251,6 +1262,8 @@ class GenerationMixin:
|
||||
synced_gpus: Optional[bool] = None,
|
||||
assistant_model: Optional["PreTrainedModel"] = None,
|
||||
streamer: Optional["BaseStreamer"] = None,
|
||||
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
r"""
|
||||
@@ -1308,6 +1321,11 @@ class GenerationMixin:
|
||||
streamer (`BaseStreamer`, *optional*):
|
||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||
negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
The negative prompt needed for some processors such as CFG. The batch size must match the input batch
|
||||
size. This is an experimental feature, subject to breaking API changes in future versions.
|
||||
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Attention_mask for `negative_prompt_ids`.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
@@ -1511,6 +1529,9 @@ class GenerationMixin:
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||
logits_processor=logits_processor,
|
||||
model_kwargs=model_kwargs,
|
||||
negative_prompt_ids=negative_prompt_ids,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
# 9. prepare stopping criteria
|
||||
|
||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
||||
|
||||
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
||||
|
||||
def test_classifier_free_guidance(self):
|
||||
class Namespace(dict):
|
||||
pass
|
||||
|
||||
logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
|
||||
logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])
|
||||
|
||||
def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
|
||||
out = Namespace()
|
||||
out.logits = logits_uncond
|
||||
out.past_key_values = None
|
||||
return out
|
||||
|
||||
def lsm(x):
|
||||
return torch.nn.functional.log_softmax(x, dim=-1)
|
||||
|
||||
# explicit unconditional prompt + attention mask
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(
|
||||
1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long)
|
||||
)
|
||||
out = cfg(input_ids, logits_cond)[0, -1]
|
||||
|
||||
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
|
||||
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
# explicit unconditional prompt
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids)
|
||||
out = cfg(input_ids, logits_cond)[0, -1]
|
||||
|
||||
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
|
||||
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
# all implicit
|
||||
input_ids = torch.LongTensor([[0]])
|
||||
cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model)
|
||||
out = cfg(input_ids, logits_cond)[0, -1]
|
||||
|
||||
res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]
|
||||
|
||||
self.assertAlmostEqual(out[0].item(), res[0].item())
|
||||
self.assertAlmostEqual(out[1].item(), res[1].item())
|
||||
self.assertAlmostEqual(out[2].item(), res[2].item())
|
||||
|
||||
@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_cfg_mixin(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
input = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
|
||||
input["input_ids"] = input["input_ids"].to(torch_device)
|
||||
input["attention_mask"] = input["attention_mask"].to(torch_device)
|
||||
|
||||
outputs = model.generate(**input, max_new_tokens=32, guidance_scale=1.5)
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
"The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
|
||||
'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
|
||||
],
|
||||
)
|
||||
|
||||
neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
|
||||
neg["input_ids"] = neg["input_ids"].to(torch_device)
|
||||
neg["attention_mask"] = neg["attention_mask"].to(torch_device)
|
||||
outputs = model.generate(
|
||||
**input,
|
||||
max_new_tokens=32,
|
||||
guidance_scale=1.5,
|
||||
negative_prompt_ids=neg["input_ids"],
|
||||
negative_prompt_attention_mask=neg["attention_mask"],
|
||||
)
|
||||
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
|
||||
self.assertListEqual(
|
||||
generated_text,
|
||||
[
|
||||
'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
|
||||
'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
|
||||
],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_constrained_beam_search_example_translation_mixin(self):
|
||||
# PT-only test: TF doesn't have constrained beam search
|
||||
|
||||
Reference in New Issue
Block a user