add CFG for .generate() (#24654)
This commit is contained in:
committed by
GitHub
parent
a6e6b1c622
commit
d533465150
@@ -65,6 +65,7 @@ else:
|
|||||||
"EncoderNoRepeatNGramLogitsProcessor",
|
"EncoderNoRepeatNGramLogitsProcessor",
|
||||||
"ExponentialDecayLengthPenalty",
|
"ExponentialDecayLengthPenalty",
|
||||||
"LogitNormalization",
|
"LogitNormalization",
|
||||||
|
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||||
]
|
]
|
||||||
_import_structure["stopping_criteria"] = [
|
_import_structure["stopping_criteria"] = [
|
||||||
"MaxNewTokensCriteria",
|
"MaxNewTokensCriteria",
|
||||||
@@ -188,6 +189,7 @@ if TYPE_CHECKING:
|
|||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Dict, Iterable, List, Tuple, Union
|
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -1334,3 +1334,119 @@ class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
|
|||||||
scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
|
scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
    r"""Logits processor for Classifier-Free Guidance (CFG). The processor
    computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
    parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
    the `unconditional_ids` branch.

    See [the paper](https://arxiv.org/abs/2306.17806) for more information.

    Args:
        guidance_scale (`float`):
            The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
            Higher guidance scale encourages the model to generate samples that are more closely linked to the input
            prompt, usually at the expense of poorer quality. A value of exactly `1` makes the processor return the
            log-softmax of the conditional scores without querying `model`.
        model (`PreTrainedModel`):
            The model computing the unconditional scores. Supposedly the same as the one computing the conditional
            scores. Both models must use the same tokenizer.
        unconditional_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to
            the last token of the prompt.
        unconditional_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Attention mask for `unconditional_ids`. If unset, a mask of ones shaped like the unconditional ids is
            used.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether to cache key/values during the negative prompt forward pass.


    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM

    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
    >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5)
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    The dragon flew over Paris, France, landing in Lyon, a city of a few million. Dragon-flying was a new form of
    transport, and the dragon was the first in Europe.

    >>> # with a negative prompt
    >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt")
    >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"])
    >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    The dragon flew over Paris, France, crashing into Notre Dame Cathedral in the French capital killing at least 127
    people and injuring more than 350.
    ```
    """

    def __init__(
        self,
        guidance_scale: float,
        model,
        unconditional_ids: Optional[torch.LongTensor] = None,
        unconditional_attention_mask: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = True,
    ):
        self.guidance_scale = guidance_scale
        self.model = model
        # Mutable state of the unconditional branch. It is updated on every call to
        # `get_unconditional_logits`, so an instance tracks a single decoding run.
        self.unconditional_context = {
            "input_ids": unconditional_ids,
            "attention_mask": unconditional_attention_mask,
            "use_cache": use_cache,
            "past_key_values": None,
            "first_pass": True,
        }

    def get_unconditional_logits(self, input_ids):
        """Run `self.model` on the unconditional branch and return its logits.

        On the first call the stored unconditional prompt is fed to the model (falling back to the last token of
        `input_ids` and to an all-ones attention mask when they were not provided). On later calls the last token of
        `input_ids` is appended to the branch: with caching only that token is fed to the model, without caching the
        whole accumulated sequence is fed again.

        Args:
            input_ids: Token ids of the conditional branch; only the last column (`input_ids[:, -1:]`) is read.

        Returns:
            The `logits` attribute of the model output.
        """
        if self.unconditional_context["first_pass"]:
            if self.unconditional_context["input_ids"] is None:
                self.unconditional_context["input_ids"] = input_ids[:, -1:]
            if self.unconditional_context["attention_mask"] is None:
                self.unconditional_context["attention_mask"] = torch.ones_like(
                    self.unconditional_context["input_ids"], dtype=torch.long
                )
            input_ids = self.unconditional_context["input_ids"]
            attention_mask = self.unconditional_context["attention_mask"]
            self.unconditional_context["first_pass"] = False
        else:
            # The mask always grows by one position, whether or not the cache is used.
            attention_mask = torch.cat(
                [
                    self.unconditional_context["attention_mask"],
                    torch.ones_like(input_ids[:, -1:], dtype=torch.long),
                ],
                dim=1,
            )
            if not self.unconditional_context["use_cache"]:
                # No cache: re-feed everything seen so far plus the newest token.
                input_ids = torch.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], dim=1)
            else:
                # Cache: earlier positions are passed through `past_key_values`.
                input_ids = input_ids[:, -1:]
            self.unconditional_context["input_ids"] = input_ids
            self.unconditional_context["attention_mask"] = attention_mask

        out = self.model(
            input_ids,
            attention_mask=attention_mask,
            use_cache=self.unconditional_context["use_cache"],
            past_key_values=self.unconditional_context["past_key_values"],
        )
        # The model output must support both `.get(...)` and attribute access to `logits`.
        self.unconditional_context["past_key_values"] = out.get("past_key_values", None)

        return out.logits

    def __call__(self, input_ids, scores):
        """Blend conditional `scores` with the unconditional branch in log-probability space.

        Args:
            input_ids: Token ids generated so far on the conditional branch.
            scores: Conditional next-token scores; they are normalised with a log-softmax over the last dimension.

        Returns:
            `guidance_scale * (scores - unconditional) + unconditional`, where both terms are log-softmaxed. When
            `guidance_scale == 1` the log-softmaxed `scores` are returned directly and the model is not called.
        """
        scores = torch.nn.functional.log_softmax(scores, dim=-1)
        if self.guidance_scale == 1:
            return scores

        logits = self.get_unconditional_logits(input_ids)

        # Keep only the last position along dim 1 of the unconditional logits
        # (presumably the sequence axis — confirm against the model's output layout).
        unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
        out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
        return out
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
|
|||||||
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
ClassifierFreeGuidanceLogitsProcessor,
|
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
EncoderRepetitionPenaltyLogitsProcessor,
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
EpsilonLogitsWarper,
|
EpsilonLogitsWarper,
|
||||||
@@ -64,6 +63,7 @@ from .logits_process import (
|
|||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
@@ -893,6 +893,9 @@ class GenerationMixin:
|
|||||||
encoder_input_ids: torch.LongTensor,
|
encoder_input_ids: torch.LongTensor,
|
||||||
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
|
||||||
logits_processor: Optional[LogitsProcessorList],
|
logits_processor: Optional[LogitsProcessorList],
|
||||||
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
) -> LogitsProcessorList:
|
) -> LogitsProcessorList:
|
||||||
"""
|
"""
|
||||||
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
|
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
|
||||||
@@ -901,6 +904,16 @@ class GenerationMixin:
|
|||||||
# instantiate processors list
|
# instantiate processors list
|
||||||
processors = LogitsProcessorList()
|
processors = LogitsProcessorList()
|
||||||
|
|
||||||
|
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
||||||
|
processors.append(
|
||||||
|
UnbatchedClassifierFreeGuidanceLogitsProcessor(
|
||||||
|
generation_config.guidance_scale,
|
||||||
|
self,
|
||||||
|
unconditional_ids=negative_prompt_ids,
|
||||||
|
unconditional_attention_mask=negative_prompt_attention_mask,
|
||||||
|
use_cache=model_kwargs["use_cache"],
|
||||||
|
)
|
||||||
|
)
|
||||||
if generation_config.sequence_bias is not None:
|
if generation_config.sequence_bias is not None:
|
||||||
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
|
processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))
|
||||||
|
|
||||||
@@ -998,8 +1011,6 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
if generation_config.forced_decoder_ids is not None:
|
if generation_config.forced_decoder_ids is not None:
|
||||||
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
|
||||||
if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
|
|
||||||
processors.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
|
|
||||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
# `LogitNormalization` should always be the last logit processor, when present
|
# `LogitNormalization` should always be the last logit processor, when present
|
||||||
if generation_config.renormalize_logits is True:
|
if generation_config.renormalize_logits is True:
|
||||||
@@ -1251,6 +1262,8 @@ class GenerationMixin:
|
|||||||
synced_gpus: Optional[bool] = None,
|
synced_gpus: Optional[bool] = None,
|
||||||
assistant_model: Optional["PreTrainedModel"] = None,
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
streamer: Optional["BaseStreamer"] = None,
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -1308,6 +1321,11 @@ class GenerationMixin:
|
|||||||
streamer (`BaseStreamer`, *optional*):
|
streamer (`BaseStreamer`, *optional*):
|
||||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
|
negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
The negative prompt needed for some processors such as CFG. The batch size must match the input batch
|
||||||
|
size. This is an experimental feature, subject to breaking API changes in future versions.
|
||||||
|
negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Attention_mask for `negative_prompt_ids`.
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
kwargs (`Dict[str, Any]`, *optional*):
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||||
@@ -1511,6 +1529,9 @@ class GenerationMixin:
|
|||||||
encoder_input_ids=inputs_tensor,
|
encoder_input_ids=inputs_tensor,
|
||||||
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
negative_prompt_ids=negative_prompt_ids,
|
||||||
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 9. prepare stopping criteria
|
# 9. prepare stopping criteria
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ if is_torch_available():
|
|||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -743,3 +744,54 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones))
|
||||||
|
|
||||||
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
self.assertTrue(normalized_scores.allclose(scores.softmax(dim=-1)))
|
||||||
|
|
||||||
|
def test_classifier_free_guidance(self):
    """Check the CFG processor output against the closed-form blend of log-softmaxed logits.

    The same expectation is verified for three ways of building the processor: explicit unconditional ids and
    attention mask, explicit unconditional ids only, and no unconditional inputs at all.
    """

    class Namespace(dict):
        pass

    logits_uncond = torch.tensor([[[1.0, 0, 1.5]]])
    logits_cond = torch.tensor([[[1.0, 1.0, 1.0]]])

    def dummy_model(input_ids, attention_mask, use_cache=True, past_key_values=None):
        # Always answers with the fixed unconditional logits, whatever the inputs are.
        out = Namespace()
        out.logits = logits_uncond
        out.past_key_values = None
        return out

    def lsm(x):
        return torch.nn.functional.log_softmax(x, dim=-1)

    input_ids = torch.LongTensor([[0]])
    scenarios = (
        # explicit unconditional prompt + attention mask
        (input_ids, torch.ones_like(input_ids, dtype=torch.long)),
        # explicit unconditional prompt
        (input_ids,),
        # all implicit
        (),
    )

    for unconditional_args in scenarios:
        cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, *unconditional_args)
        out = cfg(input_ids, logits_cond)[0, -1]

        res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1]

        for position in range(3):
            self.assertAlmostEqual(out[position].item(), res[position].item())
|
||||||
|
|||||||
@@ -2585,6 +2585,46 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
def test_cfg_mixin(self):
    """Compare GPT-2 `generate` output under `guidance_scale=1.5` with recorded text.

    Two runs are checked: one without a negative prompt and one with the negative prompt "France,".
    """
    model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tensor_keys = ("input_ids", "attention_mask")
    shared_kwargs = {"max_new_tokens": 32, "guidance_scale": 1.5}

    prompt = tokenizer(["The dragon flew over Paris,"], return_tensors="pt", return_attention_mask=True)
    for key in tensor_keys:
        prompt[key] = prompt[key].to(torch_device)

    outputs = model.generate(**prompt, **shared_kwargs)
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    self.assertListEqual(
        generated_text,
        [
            "The dragon flew over Paris, landing in the Rue de la Bastille. The crowd was so excited "
            'that they had to leave the city.\n\n"We\'re going to Paris!"\n'
        ],
    )

    # Second run: same prompt, steered away from an explicit negative prompt.
    neg = tokenizer(["France,"], return_tensors="pt", return_attention_mask=True)
    for key in tensor_keys:
        neg[key] = neg[key].to(torch_device)

    outputs = model.generate(
        **prompt,
        **shared_kwargs,
        negative_prompt_ids=neg["input_ids"],
        negative_prompt_attention_mask=neg["attention_mask"],
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    self.assertListEqual(
        generated_text,
        [
            'The dragon flew over Paris, landing on the pavement.\n\n"Paris!"\n\n"Paris!"\n\n"'
            'Paris!"\n\n"Paris!"\n\n"Paris!"\n\n'
        ],
    )
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_constrained_beam_search_example_translation_mixin(self):
|
def test_constrained_beam_search_example_translation_mixin(self):
|
||||||
# PT-only test: TF doesn't have constrained beam search
|
# PT-only test: TF doesn't have constrained beam search
|
||||||
|
|||||||
Reference in New Issue
Block a user