Add Epsilon- and Eta-Sampling (#21121)
* Add epsilon- and eta-sampling. Add epsilon- and eta-sampling, following the official code from https://github.com/john-hewitt/truncation-sampling and adapting to be more configurable, as required by Huggingface transformers. * Add unit tests for epsilon- and eta-sampling. * Black: fix code formatting. * Fix docstring spacing. * Clean up newlines. * Fix implementation bugs and their associated tests. * Remove epsilon- and eta-sampling parameters from PretrainedConfig. * Clarify and clean up the documentation. * Remove parameters for PretrainedConfig test.
This commit is contained in:
@@ -43,6 +43,8 @@ else:
|
|||||||
"ConstrainedBeamSearchScorer",
|
"ConstrainedBeamSearchScorer",
|
||||||
]
|
]
|
||||||
_import_structure["logits_process"] = [
|
_import_structure["logits_process"] = [
|
||||||
|
"EpsilonLogitsWarper",
|
||||||
|
"EtaLogitsWarper",
|
||||||
"ForcedBOSTokenLogitsProcessor",
|
"ForcedBOSTokenLogitsProcessor",
|
||||||
"ForcedEOSTokenLogitsProcessor",
|
"ForcedEOSTokenLogitsProcessor",
|
||||||
"HammingDiversityLogitsProcessor",
|
"HammingDiversityLogitsProcessor",
|
||||||
@@ -162,6 +164,8 @@ if TYPE_CHECKING:
|
|||||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EpsilonLogitsWarper,
|
||||||
|
EtaLogitsWarper,
|
||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
|
|||||||
@@ -109,6 +109,18 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
|
generated. If set to float < 1, the smallest set of the most locally typical tokens with probabilities that
|
||||||
add up to `typical_p` or higher are kept for generation. See [this
|
add up to `typical_p` or higher are kept for generation. See [this
|
||||||
paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
|
paper](https://arxiv.org/pdf/2202.00666.pdf) for more details.
|
||||||
|
epsilon_cutoff (`float`, *optional*, defaults to 0.0):
|
||||||
|
If set to float strictly between 0 and 1, only tokens with a conditional probability greater than
|
||||||
|
`epsilon_cutoff` will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on the
|
||||||
|
size of the model. See [Truncation Sampling as Language Model
|
||||||
|
Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
|
||||||
|
eta_cutoff (`float`, *optional*, defaults to 0.0):
|
||||||
|
Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly between
|
||||||
|
0 and 1, a token is only considered if it is greater than either `eta_cutoff` or `sqrt(eta_cutoff) *
|
||||||
|
exp(-entropy(softmax(next_token_logits)))`. The latter term is intuitively the expected next token
|
||||||
|
probability, scaled by `sqrt(eta_cutoff)`. In the paper, suggested values range from 3e-4 to 2e-3,
|
||||||
|
depending on the size of the model. See [Truncation Sampling as Language Model
|
||||||
|
Desmoothing](https://arxiv.org/abs/2210.15191) for more details.
|
||||||
diversity_penalty (`float`, *optional*, defaults to 0.0):
|
diversity_penalty (`float`, *optional*, defaults to 0.0):
|
||||||
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
||||||
particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
|
particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
|
||||||
@@ -223,6 +235,8 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.top_k = kwargs.pop("top_k", 50)
|
self.top_k = kwargs.pop("top_k", 50)
|
||||||
self.top_p = kwargs.pop("top_p", 1.0)
|
self.top_p = kwargs.pop("top_p", 1.0)
|
||||||
self.typical_p = kwargs.pop("typical_p", 1.0)
|
self.typical_p = kwargs.pop("typical_p", 1.0)
|
||||||
|
self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0)
|
||||||
|
self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0)
|
||||||
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
|
||||||
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
||||||
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
||||||
|
|||||||
@@ -138,7 +138,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
|
def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
|
||||||
|
|
||||||
for arg_name, arg_value in [
|
for arg_name, arg_value in [
|
||||||
("prompt_length_to_skip", prompt_length_to_skip),
|
("prompt_length_to_skip", prompt_length_to_skip),
|
||||||
("min_new_tokens", min_new_tokens),
|
("min_new_tokens", min_new_tokens),
|
||||||
@@ -152,7 +151,6 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|
||||||
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
|
||||||
if new_tokens_length < self.min_new_tokens:
|
if new_tokens_length < self.min_new_tokens:
|
||||||
scores[:, self.eos_token_id] = -float("inf")
|
scores[:, self.eos_token_id] = -float("inf")
|
||||||
@@ -297,7 +295,6 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
|
||||||
# calculate entropy
|
# calculate entropy
|
||||||
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
|
||||||
p = torch.exp(normalized)
|
p = torch.exp(normalized)
|
||||||
@@ -322,6 +319,90 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class EpsilonLogitsWarper(LogitsWarper):
|
||||||
|
r"""
|
||||||
|
[`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
|
||||||
|
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
|
||||||
|
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epsilon (`float`):
|
||||||
|
If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation.
|
||||||
|
filter_value (`float`, *optional*, defaults to `-float("Inf")`):
|
||||||
|
All filtered values will be set to this float value.
|
||||||
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||||
|
Minimum number of tokens that cannot be filtered.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
epsilon = float(epsilon)
|
||||||
|
if epsilon <= 0 or epsilon >= 1:
|
||||||
|
raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
|
||||||
|
|
||||||
|
min_tokens_to_keep = int(min_tokens_to_keep)
|
||||||
|
if min_tokens_to_keep < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Determine which indices to remove
|
||||||
|
probabilities = scores.softmax(dim=-1)
|
||||||
|
indices_to_remove = probabilities < self.epsilon
|
||||||
|
|
||||||
|
# Keep the words with the 'min_tokens_to_keep'-highest probabilities
|
||||||
|
top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
|
||||||
|
indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
|
||||||
|
|
||||||
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class EtaLogitsWarper(LogitsWarper):
|
||||||
|
r"""
|
||||||
|
[`LogitsWarper`] that performs eta-sampling, i.e. calculates a dynamic cutoff `eta := min(epsilon, sqrt(epsilon,
|
||||||
|
e^-entropy(probabilities)))` and restricts to tokens with `prob >= eta`. Takes the largest min_tokens_to_keep
|
||||||
|
tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
|
||||||
|
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||||
|
Minimum number of tokens that cannot be filtered."""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
epsilon = float(epsilon)
|
||||||
|
if epsilon <= 0 or epsilon >= 1:
|
||||||
|
raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
|
||||||
|
|
||||||
|
min_tokens_to_keep = int(min_tokens_to_keep)
|
||||||
|
if min_tokens_to_keep < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.epsilon = torch.tensor(epsilon)
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Calculate the adaptive cutoff
|
||||||
|
probabilities = scores.softmax(dim=-1)
|
||||||
|
entropy = torch.distributions.Categorical(probs=probabilities).entropy()
|
||||||
|
eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None]
|
||||||
|
indices_to_remove = probabilities < eta
|
||||||
|
|
||||||
|
# Keep the words with the 'min_tokens_to_keep'-highest probabilities
|
||||||
|
top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check
|
||||||
|
indices_to_remove = indices_to_remove & (scores < torch.topk(scores, top_k)[0][..., -1, None])
|
||||||
|
|
||||||
|
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
|
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
|
||||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||||
for idx in range(num_hypos):
|
for idx in range(num_hypos):
|
||||||
@@ -438,7 +519,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
|
def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
|
||||||
|
|
||||||
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
|
||||||
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
|
||||||
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
|
if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScor
|
|||||||
from .configuration_utils import GenerationConfig
|
from .configuration_utils import GenerationConfig
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EpsilonLogitsWarper,
|
||||||
|
EtaLogitsWarper,
|
||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
@@ -750,23 +752,22 @@ class GenerationMixin:
|
|||||||
# all samplers can be found in `generation_utils_samplers.py`
|
# all samplers can be found in `generation_utils_samplers.py`
|
||||||
if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
||||||
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
|
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
|
||||||
|
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||||
if generation_config.top_k is not None and generation_config.top_k != 0:
|
if generation_config.top_k is not None and generation_config.top_k != 0:
|
||||||
warpers.append(
|
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
TopKLogitsWarper(
|
|
||||||
top_k=generation_config.top_k, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
||||||
warpers.append(
|
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
|
||||||
TopPLogitsWarper(
|
|
||||||
top_p=generation_config.top_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
|
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
|
||||||
warpers.append(
|
warpers.append(
|
||||||
TypicalLogitsWarper(
|
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
mass=generation_config.typical_p, min_tokens_to_keep=(2 if generation_config.num_beams > 1 else 1)
|
)
|
||||||
)
|
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
|
||||||
|
warpers.append(
|
||||||
|
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
|
)
|
||||||
|
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
|
||||||
|
warpers.append(
|
||||||
|
EtaLogitsWarper(epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
)
|
)
|
||||||
# `LogitNormalization` should always be the last logit processor, when present
|
# `LogitNormalization` should always be the last logit processor, when present
|
||||||
if generation_config.renormalize_logits is True:
|
if generation_config.renormalize_logits is True:
|
||||||
@@ -1311,7 +1312,6 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif is_contrastive_search_gen_mode:
|
elif is_contrastive_search_gen_mode:
|
||||||
|
|
||||||
if generation_config.num_return_sequences > 1:
|
if generation_config.num_return_sequences > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
||||||
@@ -1716,7 +1716,6 @@ class GenerationMixin:
|
|||||||
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
|
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
|
||||||
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
|
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
|
||||||
if model_kwargs.get("past_key_values") is None:
|
if model_kwargs.get("past_key_values") is None:
|
||||||
|
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
model_kwargs["use_cache"] = True
|
model_kwargs["use_cache"] = True
|
||||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EpsilonLogitsWarper,
|
||||||
|
EtaLogitsWarper,
|
||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
@@ -288,6 +290,80 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
# first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||||
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||||
|
|
||||||
|
def test_epsilon_dist_warper(self):
|
||||||
|
input_ids = None
|
||||||
|
vocab_size = 10
|
||||||
|
batch_size = 2
|
||||||
|
|
||||||
|
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||||
|
dist = torch.log(
|
||||||
|
torch.tensor(
|
||||||
|
[[0.87, 0.099, 0.001, 0.03], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
epsilon_warp = EpsilonLogitsWarper(0.1)
|
||||||
|
filtered_dist = torch.exp(epsilon_warp(input_ids, dist))
|
||||||
|
|
||||||
|
# dist should be filtered to only keep values with proba >= 0.1
|
||||||
|
# exp (-inf) => 0
|
||||||
|
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||||
|
[[0.87, 0, 0, 0], [0.4, 0.299, 0.101, 0.2]], device=torch_device, dtype=torch.float
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||||
|
|
||||||
|
# check edge cases with negative and extreme logits
|
||||||
|
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
|
||||||
|
batch_size, 1
|
||||||
|
) - (vocab_size // 2)
|
||||||
|
|
||||||
|
# make ramp_logits more extreme
|
||||||
|
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||||
|
|
||||||
|
# make sure at least 2 tokens are kept
|
||||||
|
epsilon_warp = EpsilonLogitsWarper(5e-2, min_tokens_to_keep=2, filter_value=0.0)
|
||||||
|
filtered_dist = epsilon_warp(input_ids, ramp_logits)
|
||||||
|
|
||||||
|
# first batch should keep 3 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||||
|
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
|
||||||
|
|
||||||
|
def test_eta_dist_warper(self):
|
||||||
|
input_ids = None
|
||||||
|
vocab_size = 10
|
||||||
|
batch_size = 2
|
||||||
|
|
||||||
|
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||||
|
dist = torch.log(
|
||||||
|
torch.tensor([[0.0, 0.1, 0.8, 0.1], [0.01, 0.04, 0.9, 0.05]], device=torch_device, dtype=torch.float)
|
||||||
|
)
|
||||||
|
|
||||||
|
eta_warp = EtaLogitsWarper(0.0625)
|
||||||
|
filtered_dist = torch.exp(eta_warp(input_ids, dist))
|
||||||
|
|
||||||
|
# dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p))
|
||||||
|
# min(0.0625, 0.1320) is the cutoff for the first row and min(0.0625, 0.1644) is for the second
|
||||||
|
# where H is the entropy function and p is the probability vector.
|
||||||
|
# exp (-inf) => 0
|
||||||
|
EXPECTED_FILTERED_DIST = torch.tensor(
|
||||||
|
[[0.0, 0.1, 0.8, 0.1], [0.0, 0.0, 0.9, 0.0]], device=torch_device, dtype=torch.float
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||||
|
|
||||||
|
# check edge cases with negative and extreme logits
|
||||||
|
ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
|
||||||
|
batch_size, 1
|
||||||
|
) - (vocab_size // 2)
|
||||||
|
|
||||||
|
# make ramp_logits more extreme
|
||||||
|
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||||
|
|
||||||
|
# make sure at least 2 tokens are kept
|
||||||
|
eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0)
|
||||||
|
filtered_dist = eta_warp(input_ids, ramp_logits)
|
||||||
|
|
||||||
|
# first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||||
|
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
|
||||||
|
|
||||||
def test_no_repeat_ngram_dist_processor(self):
|
def test_no_repeat_ngram_dist_processor(self):
|
||||||
vocab_size = 3
|
vocab_size = 3
|
||||||
batch_size = 2
|
batch_size = 2
|
||||||
|
|||||||
Reference in New Issue
Block a user