From 865da84abb8d150cb33a913104186a071ea55073 Mon Sep 17 00:00:00 2001 From: Sherman Siu Date: Tue, 17 Jan 2023 13:04:32 -0500 Subject: [PATCH] Add Epsilon- and Eta-Sampling (#21121) * Add epsilon- and eta-sampling. Add epsilon- and eta-sampling, following the official code from https://github.com/john-hewitt/truncation-sampling and adapting to be more configurable, as required by Huggingface transformers. * Add unit tests for epsilon- and eta-sampling. * Black: fix code formatting. * Fix docstring spacing. * Clean up newlines. * Fix implementation bugs and their associated tests. * Remove epsilon- and eta-sampling parameters from PretrainedConfig. * Clarify and clean up the documentation. * Remove parameters for PretrainedConfig test. --- src/transformers/generation/__init__.py | 4 + .../generation/configuration_utils.py | 14 +++ src/transformers/generation/logits_process.py | 88 ++++++++++++++++++- src/transformers/generation/utils.py | 29 +++--- tests/generation/test_logits_process.py | 76 ++++++++++++++++ 5 files changed, 192 insertions(+), 19 deletions(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index d0c3a32973..95d4c9fff7 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -43,6 +43,8 @@ else: "ConstrainedBeamSearchScorer", ] _import_structure["logits_process"] = [ + "EpsilonLogitsWarper", + "EtaLogitsWarper", "ForcedBOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor", "HammingDiversityLogitsProcessor", @@ -162,6 +164,8 @@ if TYPE_CHECKING: from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index a0c851ef74..aacb1d29f4 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -109,6 +109,18 @@ class GenerationConfig(PushToHubMixin): generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that add up to `typical_p` or higher are kept for generation. See [this paper](https://arxiv.org/pdf/2202.00666.pdf) for more details. + epsilon_cutoff (`float`, *optional*, defaults to 0.0): + If set to float strictly between 0 and 1, only tokens with a conditional probability greater than + `epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the + size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://arxiv.org/abs/2210.15191) for more details. + eta_cutoff (`float`, *optional*, defaults to 0.0): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between + 0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) * + exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token + probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://arxiv.org/abs/2210.15191) for more details. diversity_penalty (`float`, *optional*, defaults to 0.0): This value is subtracted from a beam's score if it generates a token same as any beam from other group at a particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. @@ -223,6 +235,8 @@ class GenerationConfig(PushToHubMixin): self.top_k = kwargs.pop("top_k", 50) self.top_p = kwargs.pop("top_p", 1.0) self.typical_p = kwargs.pop("typical_p", 1.0) + self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0) + self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0) self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) self.length_penalty = kwargs.pop("length_penalty", 1.0) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index e3c135a5e2..e5b360a1c8 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -138,7 +138,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): """ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int): - for arg_name, arg_value in [ ("prompt_length_to_skip", prompt_length_to_skip), ("min_new_tokens", min_new_tokens), @@ -152,7 +151,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): self.eos_token_id = eos_token_id def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip if new_tokens_length < self.min_new_tokens: scores[:, self.eos_token_id] = -float("inf") @@ -297,7 +295,6 @@ class TypicalLogitsWarper(LogitsWarper): self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - # calculate entropy normalized = torch.nn.functional.log_softmax(scores, dim=-1) p = torch.exp(normalized) @@ -322,6 +319,90 @@ class TypicalLogitsWarper(LogitsWarper): return scores +class EpsilonLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the + largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model + Desmoothing](https://arxiv.org/abs/2210.15191) for more information. + + Args: + epsilon (`float`): + If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + epsilon = float(epsilon) + if epsilon <= 0 or epsilon >= 1: + raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}") + + min_tokens_to_keep = int(min_tokens_to_keep) + if min_tokens_to_keep < 1: + raise ValueError( + f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" + ) + + self.epsilon = epsilon + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # Determine which indices to remove + probabilities = scores.softmax(dim=-1) + indices_to_remove = probabilities < self.epsilon + + # Keep the words with the 'min_tokens_to_keep'-highest probabilities + top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check + indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None]) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + +class EtaLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs eta-sampling, i.e. calculates a dynamic cutoff `eta := min(epsilon, sqrt(epsilon, + e^-entropy(probabilities)))` and restricts to tokens with `prob >= eta`. Takes the largest min_tokens_to_keep + tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model + Desmoothing](https://arxiv.org/abs/2210.15191) for more information. + + Args: + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered.""" + + def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + epsilon = float(epsilon) + if epsilon <= 0 or epsilon >= 1: + raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}") + + min_tokens_to_keep = int(min_tokens_to_keep) + if min_tokens_to_keep < 1: + raise ValueError( + f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" + ) + + self.epsilon = torch.tensor(epsilon) + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # Calculate the adaptive cutoff + probabilities = scores.softmax(dim=-1) + entropy = torch.distributions.Categorical(probs=probabilities).entropy() + eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] + indices_to_remove = probabilities < eta + + # Keep the words with the 'min_tokens_to_keep'-highest probabilities + top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check + indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None]) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int): generated_ngrams = [{} for _ in range(num_hypos)] for idx in range(num_hypos): @@ -438,7 +519,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): """ def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): - if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 64f5c3822f..1ce610eabd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -39,6 +39,8 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScor from .configuration_utils import GenerationConfig from .logits_process import ( EncoderNoRepeatNGramLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, @@ -750,23 +752,22 @@ class GenerationMixin: # all samplers can be found in `generation_utils_samplers.py` if generation_config.temperature is not None and generation_config.temperature != 1.0: warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1 if generation_config.top_k is not None and generation_config.top_k != 0: - warpers.append( - TopKLogitsWarper( - top_k=generation_config.top_k, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) - ) - ) + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.top_p is not None and generation_config.top_p < 1.0: - warpers.append( - TopPLogitsWarper( - top_p=generation_config.top_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) - ) - ) + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) if generation_config.typical_p is not None and generation_config.typical_p < 1.0: warpers.append( - TypicalLogitsWarper( - mass=generation_config.typical_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1) - ) + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0: + warpers.append( + EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep) + ) + if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0: + warpers.append( + EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep) ) # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: @@ -1311,7 +1312,6 @@ class GenerationMixin: ) elif is_contrastive_search_gen_mode: - if generation_config.num_return_sequences > 1: raise ValueError( f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" @@ -1716,7 +1716,6 @@ class GenerationMixin: # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past_key_values") is None: - # prepare inputs model_kwargs["use_cache"] = True model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 5a47884f4a..e81a5c865f 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -28,6 +28,8 @@ if is_torch_available(): from transformers.generation import ( EncoderNoRepeatNGramLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, ExponentialDecayLengthPenalty, ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, @@ -288,6 +290,80 @@ class LogitsProcessorTest(unittest.TestCase): # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) + def test_epsilon_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) + dist = torch.log( + torch.tensor( + [[0.87, 0.099, 0.001, 0.03], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float + ) + ) + + epsilon_warp = EpsilonLogitsWarper(0.1) + filtered_dist = torch.exp(epsilon_warp(input_ids, dist)) + + # dist should be filtered to only keep values with proba >= 0.1 + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = torch.tensor( + [[0.87, 0, 0, 0], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float + ) + self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + + # check edge cases with negative and extreme logits + ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( + batch_size, 1 + ) - (vocab_size // 2) + + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + epsilon_warp = EpsilonLogitsWarper(5e-2, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = epsilon_warp(input_ids, ramp_logits) + + # first batch should keep 3 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. + self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2]) + + def test_eta_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) + dist = torch.log( + torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float) + ) + + eta_warp = EtaLogitsWarper(0.0625) + filtered_dist = torch.exp(eta_warp(input_ids, dist)) + + # dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p)) + # min(0.0625, 0.1320) is the cutoff for the first row and min(0.0625, 0.1644) is for the second + # where H is the entropy function and p is the probability vector. + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = torch.tensor( + [[0.0, 0.1, 0.8, 0.1], [0.0, 0.0, 0.9, 0.0]], device=torch_device, dtype=torch.float + ) + self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + + # check edge cases with negative and extreme logits + ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat( + batch_size, 1 + ) - (vocab_size // 2) + + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = eta_warp(input_ids, ramp_logits) + + # first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. + self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) + def test_no_repeat_ngram_dist_processor(self): vocab_size = 3 batch_size = 2