Generate: unify LogitsWarper and LogitsProcessor (#32626)
This commit is contained in:
@@ -158,9 +158,6 @@ generation.
|
|||||||
[[autodoc]] LogitsProcessorList
|
[[autodoc]] LogitsProcessorList
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] LogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] MinLengthLogitsProcessor
|
[[autodoc]] MinLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
@@ -421,4 +418,3 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
|
|
||||||
[[autodoc]] WatermarkDetector
|
[[autodoc]] WatermarkDetector
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
|||||||
@@ -157,9 +157,6 @@ generation_output[:2]
|
|||||||
[[autodoc]] LogitsProcessorList
|
[[autodoc]] LogitsProcessorList
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] LogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] MinLengthLogitsProcessor
|
[[autodoc]] MinLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
|||||||
@@ -151,9 +151,6 @@ generation_output[:2]
|
|||||||
[[autodoc]] LogitsProcessorList
|
[[autodoc]] LogitsProcessorList
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] LogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] MinLengthLogitsProcessor
|
[[autodoc]] MinLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
|
|||||||
@@ -190,9 +190,9 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
|
triggers a [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one
|
||||||
can allow different forms of each word.
|
can allow different forms of each word.
|
||||||
renormalize_logits (`bool`, *optional*, defaults to `False`):
|
renormalize_logits (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to renormalize the logits after applying all the logits processors or warpers (including the custom
|
Whether to renormalize the logits after applying all the logits processors (including the custom
|
||||||
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
|
ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the score logits
|
||||||
are normalized but some logit processors or warpers break the normalization.
|
are normalized but some logit processors break the normalization.
|
||||||
constraints (`List[Constraint]`, *optional*):
|
constraints (`List[Constraint]`, *optional*):
|
||||||
Custom constraints that can be added to the generation to ensure that the output will contain the use of
|
Custom constraints that can be added to the generation to ensure that the output will contain the use of
|
||||||
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
certain tokens as defined by `Constraint` objects, in the most sensible way possible.
|
||||||
|
|||||||
@@ -55,6 +55,12 @@ class LogitsProcessor:
|
|||||||
class LogitsWarper:
|
class LogitsWarper:
|
||||||
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
logger.warning_once(
|
||||||
|
"`LogitsWarper` is deprecated and will be removed in v4.48. Your class should inherit `LogitsProcessor` "
|
||||||
|
"instead, which has the same properties and interface."
|
||||||
|
)
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -64,9 +70,9 @@ class LogitsWarper:
|
|||||||
|
|
||||||
class LogitsProcessorList(list):
|
class LogitsProcessorList(list):
|
||||||
"""
|
"""
|
||||||
This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
|
This class can be used to create a list of [`LogitsProcessor`] to subsequently process a `scores` input tensor.
|
||||||
`scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
|
This class inherits from list and adds a specific *__call__* method to apply each [`LogitsProcessor`] to the
|
||||||
[`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
|
inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
|
||||||
@@ -233,9 +239,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class TemperatureLogitsWarper(LogitsWarper):
|
class TemperatureLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] for temperature (exponential scaling output probability distribution), which effectively means
|
[`LogitsProcessor`] for temperature (exponential scaling output probability distribution), which effectively means
|
||||||
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
|
that it can control the randomness of the predicted tokens. Often used together with [`TopPLogitsWarper`] and
|
||||||
[`TopKLogitsWarper`].
|
[`TopKLogitsWarper`].
|
||||||
|
|
||||||
@@ -408,10 +414,10 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class TopPLogitsWarper(LogitsWarper):
|
class TopPLogitsWarper(LogitsProcessor):
|
||||||
"""
|
"""
|
||||||
[`LogitsWarper`] that performs top-p, i.e. restricting to the most probable tokens whose cumulative probability reaches `top_p`. Often
|
[`LogitsProcessor`] that performs top-p, i.e. restricting to the most probable tokens whose cumulative probability reaches `top_p`.
|
||||||
used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
|
Often used together with [`TemperatureLogitsWarper`] and [`TopKLogitsWarper`].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
top_p (`float`):
|
top_p (`float`):
|
||||||
@@ -475,10 +481,10 @@ class TopPLogitsWarper(LogitsWarper):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class TopKLogitsWarper(LogitsWarper):
|
class TopKLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. Often used together
|
[`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. Often used
|
||||||
with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
|
together with [`TemperatureLogitsWarper`] and [`TopPLogitsWarper`].
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
top_k (`int`):
|
top_k (`int`):
|
||||||
@@ -528,9 +534,9 @@ class TopKLogitsWarper(LogitsWarper):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class MinPLogitsWarper(LogitsWarper):
|
class MinPLogitsWarper(LogitsProcessor):
|
||||||
"""
|
"""
|
||||||
[`LogitsWarper`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
|
[`LogitsProcessor`] that performs min-p, i.e. keeps all tokens that are above a minimum probability, scaled by the
|
||||||
probability of the most likely token. As a result, the filter becomes more aggressive in the presence of
|
probability of the most likely token. As a result, the filter becomes more aggressive in the presence of
|
||||||
high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
|
high-probability tokens, which is a sign of a confident output that we shouldn't deviate from.
|
||||||
|
|
||||||
@@ -605,11 +611,11 @@ class MinPLogitsWarper(LogitsWarper):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class TypicalLogitsWarper(LogitsWarper):
|
class TypicalLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs typical decoding. Inspired by how humans use language, it prioritizes tokens whose
|
[`LogitsProcessor`] that performs typical decoding. Inspired by how humans use language, it prioritizes tokens
|
||||||
log probability is close to the entropy of the token probability distribution. This means that the most likely
|
whose log probability is close to the entropy of the token probability distribution. This means that the most
|
||||||
tokens may be discarded in the process.
|
likely tokens may be discarded in the process.
|
||||||
|
|
||||||
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
|
See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
|
||||||
|
|
||||||
@@ -693,9 +699,9 @@ class TypicalLogitsWarper(LogitsWarper):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class EpsilonLogitsWarper(LogitsWarper):
|
class EpsilonLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
|
[`LogitsProcessor`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
|
||||||
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
|
largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
|
||||||
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
|
Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
|
||||||
|
|
||||||
@@ -762,15 +768,15 @@ class EpsilonLogitsWarper(LogitsWarper):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class EtaLogitsWarper(LogitsWarper):
|
class EtaLogitsWarper(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
|
[`LogitsProcessor`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
|
||||||
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
|
cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
|
||||||
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
|
the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
|
||||||
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
|
min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
|
||||||
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
|
samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
|
||||||
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
|
Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
|
||||||
must be set to `True` for this `LogitsWarper` to work.
|
must be set to `True` for this `LogitsProcessor` to work.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -1708,9 +1714,9 @@ class ExponentialDecayLengthPenalty(LogitsProcessor):
|
|||||||
return scores_processed
|
return scores_processed
|
||||||
|
|
||||||
|
|
||||||
class LogitNormalization(LogitsProcessor, LogitsWarper):
|
class LogitNormalization(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
|
[`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
|
||||||
the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in
|
the scores during beam search, after applying the logits processors, since the search algorithm used in
|
||||||
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
|
this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that
|
||||||
the scores are normalized when comparing the hypotheses.
|
the scores are normalized when comparing the hypotheses.
|
||||||
|
|||||||
@@ -735,61 +735,6 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
return candidate_generator
|
return candidate_generator
|
||||||
|
|
||||||
def _get_logits_warper(
|
|
||||||
self,
|
|
||||||
generation_config: GenerationConfig,
|
|
||||||
device: str,
|
|
||||||
) -> LogitsProcessorList:
|
|
||||||
"""
|
|
||||||
This method returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
|
|
||||||
used for multinomial sampling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# instantiate warpers list
|
|
||||||
warpers = LogitsProcessorList()
|
|
||||||
|
|
||||||
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
|
|
||||||
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
|
|
||||||
if generation_config.num_beams > 1:
|
|
||||||
if isinstance(generation_config._eos_token_tensor, list):
|
|
||||||
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
|
|
||||||
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
|
|
||||||
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
|
|
||||||
else:
|
|
||||||
min_tokens_to_keep = 2
|
|
||||||
else:
|
|
||||||
min_tokens_to_keep = 1
|
|
||||||
|
|
||||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
|
||||||
# all samplers can be found in `generation_utils_samplers.py`
|
|
||||||
if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
|
||||||
warpers.append(TemperatureLogitsWarper(generation_config.temperature))
|
|
||||||
if generation_config.top_k is not None and generation_config.top_k != 0:
|
|
||||||
warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
|
|
||||||
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
|
||||||
warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
|
|
||||||
if generation_config.min_p is not None:
|
|
||||||
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
|
|
||||||
warpers.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
|
|
||||||
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
|
|
||||||
warpers.append(
|
|
||||||
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
|
|
||||||
)
|
|
||||||
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
|
|
||||||
warpers.append(
|
|
||||||
EpsilonLogitsWarper(epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep)
|
|
||||||
)
|
|
||||||
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
|
|
||||||
warpers.append(
|
|
||||||
EtaLogitsWarper(
|
|
||||||
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# `LogitNormalization` should always be the last logit processor, when present
|
|
||||||
if generation_config.renormalize_logits is True:
|
|
||||||
warpers.append(LogitNormalization())
|
|
||||||
return warpers
|
|
||||||
|
|
||||||
def _get_logits_processor(
|
def _get_logits_processor(
|
||||||
self,
|
self,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
@@ -960,7 +905,58 @@ class GenerationMixin:
|
|||||||
context_width=generation_config.watermarking_config.context_width,
|
context_width=generation_config.watermarking_config.context_width,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO (joao): find a strategy to specify the order of the processors
|
||||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
|
|
||||||
|
# Processors previously known as `LogitsWarpers`, only applied with sampling strategies
|
||||||
|
if generation_config.do_sample:
|
||||||
|
# In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
|
||||||
|
# better score (i.e. keep len(list(generation_config._eos_token_tensor)) + 1)
|
||||||
|
if generation_config.num_beams > 1:
|
||||||
|
if isinstance(generation_config._eos_token_tensor, list):
|
||||||
|
min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
|
||||||
|
elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
|
||||||
|
min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
|
||||||
|
else:
|
||||||
|
min_tokens_to_keep = 2
|
||||||
|
else:
|
||||||
|
min_tokens_to_keep = 1
|
||||||
|
|
||||||
|
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||||
|
# all samplers can be found in `generation_utils_samplers.py`
|
||||||
|
if generation_config.temperature is not None and generation_config.temperature != 1.0:
|
||||||
|
processors.append(TemperatureLogitsWarper(generation_config.temperature))
|
||||||
|
if generation_config.top_k is not None and generation_config.top_k != 0:
|
||||||
|
processors.append(
|
||||||
|
TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
|
)
|
||||||
|
if generation_config.top_p is not None and generation_config.top_p < 1.0:
|
||||||
|
processors.append(
|
||||||
|
TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
|
)
|
||||||
|
if generation_config.min_p is not None:
|
||||||
|
# Applied after temperature scaling (see https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
|
||||||
|
processors.append(
|
||||||
|
MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
|
)
|
||||||
|
if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
|
||||||
|
processors.append(
|
||||||
|
TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
|
||||||
|
)
|
||||||
|
if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
|
||||||
|
processors.append(
|
||||||
|
EpsilonLogitsWarper(
|
||||||
|
epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
|
||||||
|
processors.append(
|
||||||
|
EtaLogitsWarper(
|
||||||
|
epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# `LogitNormalization` should always be the last logit processor, when present
|
# `LogitNormalization` should always be the last logit processor, when present
|
||||||
if generation_config.renormalize_logits is True:
|
if generation_config.renormalize_logits is True:
|
||||||
processors.append(LogitNormalization())
|
processors.append(LogitNormalization())
|
||||||
@@ -1940,22 +1936,11 @@ class GenerationMixin:
|
|||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12. prepare logits warper (if `do_sample` is `True`)
|
# 12. run assisted generate
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(
|
|
||||||
generation_config,
|
|
||||||
device=input_ids.device,
|
|
||||||
)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 13. run assisted generate
|
|
||||||
result = self._assisted_decoding(
|
result = self._assisted_decoding(
|
||||||
input_ids,
|
input_ids,
|
||||||
candidate_generator=candidate_generator,
|
candidate_generator=candidate_generator,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
@@ -1968,16 +1953,10 @@ class GenerationMixin:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
|
f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
|
||||||
)
|
)
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
result = self._dola_decoding(
|
result = self._dola_decoding(
|
||||||
input_ids,
|
input_ids,
|
||||||
dola_layers=generation_config.dola_layers,
|
dola_layers=generation_config.dola_layers,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
@@ -2005,14 +1984,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||||
# 11. prepare logits warper
|
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
expand_size=generation_config.num_return_sequences,
|
expand_size=generation_config.num_return_sequences,
|
||||||
@@ -2020,11 +1992,10 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
|
# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
|
||||||
result = self._sample(
|
result = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
@@ -2033,14 +2004,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
|
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
|
||||||
# 11. prepare logits warper
|
# 11. prepare beam search scorer
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 12. prepare beam search scorer
|
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_beams=generation_config.num_beams,
|
num_beams=generation_config.num_beams,
|
||||||
@@ -2051,7 +2015,7 @@ class GenerationMixin:
|
|||||||
max_length=generation_config.max_length,
|
max_length=generation_config.max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 13. interleave input_ids with `num_beams` additional sequences per batch
|
# 12. interleave input_ids with `num_beams` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
expand_size=generation_config.num_beams,
|
expand_size=generation_config.num_beams,
|
||||||
@@ -2059,12 +2023,11 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 14. run beam sample
|
# 13. run beam sample
|
||||||
result = self._beam_search(
|
result = self._beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
@@ -2287,7 +2250,6 @@ class GenerationMixin:
|
|||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
synced_gpus: bool,
|
synced_gpus: bool,
|
||||||
streamer: "BaseStreamer",
|
streamer: "BaseStreamer",
|
||||||
logits_warper: Optional[LogitsProcessorList],
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -2316,10 +2278,6 @@ class GenerationMixin:
|
|||||||
streamer (`BaseStreamer`, *optional*):
|
streamer (`BaseStreamer`, *optional*):
|
||||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
logits_warper (`LogitsProcessorList`, *optional*):
|
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
|
||||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
|
||||||
sampling at each generation step.
|
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
|
||||||
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
@@ -2344,11 +2302,6 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
||||||
do_sample = generation_config.do_sample
|
do_sample = generation_config.do_sample
|
||||||
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
|
|
||||||
raise ValueError(
|
|
||||||
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
|
|
||||||
f"{logits_warper})."
|
|
||||||
)
|
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
@@ -2436,8 +2389,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
# pre-process distribution
|
# pre-process distribution
|
||||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||||
if do_sample: # sample
|
|
||||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
|
||||||
# Store scores, attentions and hidden_states when required
|
# Store scores, attentions and hidden_states when required
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
if output_scores:
|
if output_scores:
|
||||||
@@ -2893,7 +2845,6 @@ class GenerationMixin:
|
|||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
synced_gpus: bool,
|
synced_gpus: bool,
|
||||||
streamer: Optional["BaseStreamer"],
|
streamer: Optional["BaseStreamer"],
|
||||||
logits_warper: Optional[LogitsProcessorList],
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -2916,11 +2867,6 @@ class GenerationMixin:
|
|||||||
streamer (`BaseStreamer`, *optional*):
|
streamer (`BaseStreamer`, *optional*):
|
||||||
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
|
||||||
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
|
||||||
logits_warper (`LogitsProcessorList`, *optional*):
|
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
|
||||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
|
||||||
sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
|
|
||||||
`generation_config`)
|
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
@@ -2942,11 +2888,6 @@ class GenerationMixin:
|
|||||||
max_length = generation_config.max_length
|
max_length = generation_config.max_length
|
||||||
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
|
||||||
do_sample = generation_config.do_sample
|
do_sample = generation_config.do_sample
|
||||||
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
|
|
||||||
raise ValueError(
|
|
||||||
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
|
|
||||||
f"{logits_warper})."
|
|
||||||
)
|
|
||||||
|
|
||||||
# init attention / hidden states / scores tuples
|
# init attention / hidden states / scores tuples
|
||||||
scores = () if (return_dict_in_generate and output_scores) else None
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
@@ -2990,8 +2931,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# pre-process distribution
|
# pre-process distribution
|
||||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||||
if do_sample:
|
|
||||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
|
||||||
|
|
||||||
# Store scores, attentions and hidden_states when required
|
# Store scores, attentions and hidden_states when required
|
||||||
if return_dict_in_generate:
|
if return_dict_in_generate:
|
||||||
@@ -3105,7 +3044,6 @@ class GenerationMixin:
|
|||||||
stopping_criteria: StoppingCriteriaList,
|
stopping_criteria: StoppingCriteriaList,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
synced_gpus: bool,
|
synced_gpus: bool,
|
||||||
logits_warper: Optional[LogitsProcessorList],
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
) -> Union[GenerateBeamOutput, torch.LongTensor]:
|
||||||
r"""
|
r"""
|
||||||
@@ -3128,11 +3066,6 @@ class GenerationMixin:
|
|||||||
The generation configuration to be used as parametrization of the decoding method.
|
The generation configuration to be used as parametrization of the decoding method.
|
||||||
synced_gpus (`bool`):
|
synced_gpus (`bool`):
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
||||||
logits_warper (`LogitsProcessorList`, *optional*):
|
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
|
||||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
|
||||||
sampling at each generation step. Only required with sampling strategies (i.e. `do_sample` is set in
|
|
||||||
`generation_config`)
|
|
||||||
model_kwargs:
|
model_kwargs:
|
||||||
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
|
||||||
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
an encoder-decoder model the kwargs should include `encoder_outputs`.
|
||||||
@@ -3154,11 +3087,6 @@ class GenerationMixin:
|
|||||||
return_dict_in_generate = generation_config.return_dict_in_generate
|
return_dict_in_generate = generation_config.return_dict_in_generate
|
||||||
sequential = generation_config.low_memory
|
sequential = generation_config.low_memory
|
||||||
do_sample = generation_config.do_sample
|
do_sample = generation_config.do_sample
|
||||||
if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
|
|
||||||
raise ValueError(
|
|
||||||
"`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
|
|
||||||
f"{logits_warper})."
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size = len(beam_scorer._beam_hyps)
|
batch_size = len(beam_scorer._beam_hyps)
|
||||||
num_beams = beam_scorer.num_beams
|
num_beams = beam_scorer.num_beams
|
||||||
@@ -3249,8 +3177,6 @@ class GenerationMixin:
|
|||||||
) # (batch_size * num_beams, vocab_size)
|
) # (batch_size * num_beams, vocab_size)
|
||||||
|
|
||||||
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
|
||||||
if do_sample:
|
|
||||||
next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
|
|
||||||
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
|
||||||
next_token_scores_processed
|
next_token_scores_processed
|
||||||
)
|
)
|
||||||
@@ -3698,10 +3624,6 @@ class GenerationMixin:
|
|||||||
stopping_criteria (`StoppingCriteriaList`):
|
stopping_criteria (`StoppingCriteriaList`):
|
||||||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||||
used to tell if the generation loop should stop.
|
used to tell if the generation loop should stop.
|
||||||
logits_warper (`LogitsProcessorList`):
|
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
|
||||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
|
||||||
sampling at each generation step.
|
|
||||||
generation_config ([`~generation.GenerationConfig`]):
|
generation_config ([`~generation.GenerationConfig`]):
|
||||||
The generation configuration to be used as parametrization of the decoding method.
|
The generation configuration to be used as parametrization of the decoding method.
|
||||||
synced_gpus (`bool`):
|
synced_gpus (`bool`):
|
||||||
@@ -3915,7 +3837,6 @@ class GenerationMixin:
|
|||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
candidate_generator: CandidateGenerator,
|
candidate_generator: CandidateGenerator,
|
||||||
logits_processor: LogitsProcessorList,
|
logits_processor: LogitsProcessorList,
|
||||||
logits_warper: LogitsProcessorList,
|
|
||||||
stopping_criteria: StoppingCriteriaList,
|
stopping_criteria: StoppingCriteriaList,
|
||||||
generation_config: GenerationConfig,
|
generation_config: GenerationConfig,
|
||||||
synced_gpus: bool,
|
synced_gpus: bool,
|
||||||
@@ -3937,10 +3858,6 @@ class GenerationMixin:
|
|||||||
logits_processor (`LogitsProcessorList`):
|
logits_processor (`LogitsProcessorList`):
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
||||||
used to modify the prediction scores of the language modeling head applied at each generation step.
|
used to modify the prediction scores of the language modeling head applied at each generation step.
|
||||||
logits_warper (`LogitsProcessorList`):
|
|
||||||
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
|
|
||||||
to warp the prediction score distribution of the language modeling head applied before multinomial
|
|
||||||
sampling at each generation step. Only used if sampling is active.
|
|
||||||
stopping_criteria (`StoppingCriteriaList`):
|
stopping_criteria (`StoppingCriteriaList`):
|
||||||
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
||||||
used to tell if the generation loop should stop.
|
used to tell if the generation loop should stop.
|
||||||
@@ -3963,7 +3880,7 @@ class GenerationMixin:
|
|||||||
`model.config.is_encoder_decoder=True`.
|
`model.config.is_encoder_decoder=True`.
|
||||||
"""
|
"""
|
||||||
# init values
|
# init values
|
||||||
do_sample = logits_warper is not None
|
do_sample = generation_config.do_sample
|
||||||
output_attentions = generation_config.output_attentions
|
output_attentions = generation_config.output_attentions
|
||||||
output_hidden_states = generation_config.output_hidden_states
|
output_hidden_states = generation_config.output_hidden_states
|
||||||
output_scores = generation_config.output_scores
|
output_scores = generation_config.output_scores
|
||||||
@@ -4047,9 +3964,6 @@ class GenerationMixin:
|
|||||||
if len(logits_processor) > 0:
|
if len(logits_processor) > 0:
|
||||||
for i in range(candidate_length + 1):
|
for i in range(candidate_length + 1):
|
||||||
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
||||||
if do_sample and len(logits_warper) > 0:
|
|
||||||
for i in range(candidate_length + 1):
|
|
||||||
new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
|
|
||||||
|
|
||||||
# 3. Select the accepted tokens. There are two possible cases:
|
# 3. Select the accepted tokens. There are two possible cases:
|
||||||
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
|
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
|
||||||
|
|||||||
@@ -56,9 +56,9 @@ class BarkSemanticGenerationConfig(GenerationConfig):
|
|||||||
eos_token_id (`int`, *optional*, defaults to 10_000):
|
eos_token_id (`int`, *optional*, defaults to 10_000):
|
||||||
The id of the *end-of-sequence* token.
|
The id of the *end-of-sequence* token.
|
||||||
renormalize_logits (`bool`, *optional*, defaults to `True`):
|
renormalize_logits (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to renormalize the logits after applying all the logits processors or warpers (including the
|
Whether to renormalize the logits after applying all the logits processors (including the
|
||||||
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
|
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
|
||||||
score logits are normalized but some logit processors or warpers break the normalization.
|
score logits are normalized but some logit processors break the normalization.
|
||||||
max_new_tokens (`int`, *optional*, defaults to 768):
|
max_new_tokens (`int`, *optional*, defaults to 768):
|
||||||
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
|
||||||
output_scores (`bool`, *optional*, defaults to `False`):
|
output_scores (`bool`, *optional*, defaults to `False`):
|
||||||
@@ -143,9 +143,9 @@ class BarkCoarseGenerationConfig(GenerationConfig):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
renormalize_logits (`bool`, *optional*, defaults to `True`):
|
renormalize_logits (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to renormalize the logits after applying all the logits processors or warpers (including the
|
Whether to renormalize the logits after applying all the logits processors (including the
|
||||||
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
|
custom ones). It's highly recommended to set this flag to `True` as the search algorithms suppose the
|
||||||
score logits are normalized but some logit processors or warpers break the normalization.
|
score logits are normalized but some logit processors break the normalization.
|
||||||
output_scores (`bool`, *optional*, defaults to `False`):
|
output_scores (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
|
||||||
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
||||||
|
|||||||
@@ -1609,13 +1609,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||||
# 11. prepare logits warper
|
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -1623,11 +1616,10 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 11. run sample
|
||||||
outputs = self._sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
@@ -2649,13 +2641,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||||
# 11. prepare logits warper
|
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -2664,11 +2649,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 11. run sample
|
||||||
outputs = self._sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
|
|||||||
@@ -1531,13 +1531,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||||
# 11. prepare logits warper
|
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -1545,11 +1538,10 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 11. run sample
|
||||||
outputs = self._sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
@@ -2490,13 +2482,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
|
||||||
# 11. prepare logits warper
|
|
||||||
prepared_logits_warper = (
|
|
||||||
self._get_logits_warper(generation_config, device=input_ids.device)
|
|
||||||
if generation_config.do_sample
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# expand input_ids with `num_return_sequences` additional sequences per batch
|
# expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -2505,11 +2490,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 11. run sample
|
||||||
outputs = self._sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=prepared_logits_warper,
|
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=synced_gpus,
|
synced_gpus=synced_gpus,
|
||||||
|
|||||||
@@ -1558,7 +1558,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=False,
|
synced_gpus=False,
|
||||||
streamer=None,
|
streamer=None,
|
||||||
logits_warper=None,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
elif generation_config.num_beams > 1:
|
elif generation_config.num_beams > 1:
|
||||||
@@ -1580,7 +1579,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
synced_gpus=False,
|
synced_gpus=False,
|
||||||
logits_warper=None,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -118,26 +118,24 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
@staticmethod
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = {
|
||||||
input_length,
|
|
||||||
forced_bos_token_id=None,
|
|
||||||
forced_eos_token_id=None,
|
|
||||||
):
|
|
||||||
process_kwargs = {
|
|
||||||
"bad_words_ids": [[1, 0]],
|
"bad_words_ids": [[1, 0]],
|
||||||
"repetition_penalty": 1.2,
|
"repetition_penalty": 1.2,
|
||||||
"remove_invalid_values": True,
|
"remove_invalid_values": True,
|
||||||
}
|
}
|
||||||
# NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
|
if do_sample:
|
||||||
if forced_bos_token_id is None and forced_eos_token_id is None:
|
logits_processor_kwargs.update(
|
||||||
process_kwargs["no_repeat_ngram_size"] = 2
|
{
|
||||||
|
"top_k": 10,
|
||||||
|
"top_p": 0.7,
|
||||||
|
"temperature": 0.7,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
|
return logits_processor_kwargs
|
||||||
return process_kwargs, warp_kwargs
|
|
||||||
|
|
||||||
@staticmethod
|
def _get_beam_kwargs(self, num_return_sequences=1):
|
||||||
def _get_beam_kwargs(num_return_sequences=1):
|
|
||||||
beam_kwargs = {
|
beam_kwargs = {
|
||||||
"early_stopping": False,
|
"early_stopping": False,
|
||||||
"length_penalty": 2.0,
|
"length_penalty": 2.0,
|
||||||
@@ -146,8 +144,7 @@ class GenerationTesterMixin:
|
|||||||
}
|
}
|
||||||
return beam_kwargs
|
return beam_kwargs
|
||||||
|
|
||||||
@staticmethod
|
def _get_diverse_beam_kwargs(self, num_return_sequences=1):
|
||||||
def _get_diverse_beam_kwargs(num_return_sequences=1):
|
|
||||||
beam_kwargs = {
|
beam_kwargs = {
|
||||||
"early_stopping": False,
|
"early_stopping": False,
|
||||||
"length_penalty": 2.0,
|
"length_penalty": 2.0,
|
||||||
@@ -158,8 +155,7 @@ class GenerationTesterMixin:
|
|||||||
}
|
}
|
||||||
return beam_kwargs
|
return beam_kwargs
|
||||||
|
|
||||||
@staticmethod
|
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
|
||||||
def _get_constrained_beam_kwargs(num_return_sequences=1):
|
|
||||||
beam_kwargs = {
|
beam_kwargs = {
|
||||||
"early_stopping": False,
|
"early_stopping": False,
|
||||||
"length_penalty": 2.0,
|
"length_penalty": 2.0,
|
||||||
@@ -199,12 +195,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
input_ids.shape[-1],
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -216,7 +207,7 @@ class GenerationTesterMixin:
|
|||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**logits_process_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -228,8 +219,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
logits_warper_kwargs,
|
|
||||||
process_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -237,6 +226,7 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -249,8 +239,7 @@ class GenerationTesterMixin:
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**logits_warper_kwargs,
|
**logits_processor_kwargs,
|
||||||
**process_kwargs,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -262,13 +251,13 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_process_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -280,7 +269,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -292,7 +281,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_warper_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -300,6 +288,7 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -311,7 +300,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_warper_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -323,13 +312,13 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_process_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -341,7 +330,7 @@ class GenerationTesterMixin:
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -354,13 +343,13 @@ class GenerationTesterMixin:
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
constraints,
|
constraints,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_process_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -373,7 +362,7 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
**beam_kwargs,
|
**beam_kwargs,
|
||||||
**logits_process_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -395,12 +384,7 @@ class GenerationTesterMixin:
|
|||||||
"top_k": 5,
|
"top_k": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||||
input_ids.shape[-1],
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -412,7 +396,7 @@ class GenerationTesterMixin:
|
|||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
**logits_process_kwargs,
|
**logits_processor_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
**contrastive_search_kwargs,
|
**contrastive_search_kwargs,
|
||||||
)
|
)
|
||||||
@@ -495,19 +479,11 @@ class GenerationTesterMixin:
|
|||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -521,20 +497,11 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
num_return_sequences=2,
|
num_return_sequences=2,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -561,19 +528,12 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
output_generate = self._beam_search_generate(
|
output_generate = self._beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -589,18 +549,12 @@ class GenerationTesterMixin:
|
|||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
output_generate = self._beam_search_generate(
|
output_generate = self._beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -633,12 +587,6 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
@@ -649,7 +597,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -693,17 +640,13 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
output_generate = self._beam_sample_generate(
|
output_generate = self._beam_sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -711,7 +654,13 @@ class GenerationTesterMixin:
|
|||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
|
prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
|
||||||
|
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
|
||||||
|
# code is up to date with our most recent standards
|
||||||
|
if (
|
||||||
|
"inputs_embeds" in prepare_inputs_for_generation_args
|
||||||
|
and "cache_positions" in prepare_inputs_for_generation_args
|
||||||
|
):
|
||||||
input_embeds = model.get_input_embeddings()(input_ids)
|
input_embeds = model.get_input_embeddings()(input_ids)
|
||||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||||
output_generate2 = self._beam_sample_generate(
|
output_generate2 = self._beam_sample_generate(
|
||||||
@@ -719,7 +668,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
||||||
@@ -732,7 +680,6 @@ class GenerationTesterMixin:
|
|||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
output_generate = self._beam_sample_generate(
|
output_generate = self._beam_sample_generate(
|
||||||
@@ -740,7 +687,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -788,12 +734,6 @@ class GenerationTesterMixin:
|
|||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# check `generate()` and `group_beam_search()` are equal
|
# check `generate()` and `group_beam_search()` are equal
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||||
output_generate = self._group_beam_search_generate(
|
output_generate = self._group_beam_search_generate(
|
||||||
@@ -801,7 +741,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -816,7 +755,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
@@ -829,19 +767,12 @@ class GenerationTesterMixin:
|
|||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||||
output_generate = self._group_beam_search_generate(
|
output_generate = self._group_beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -871,12 +802,6 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
min_id = 3
|
min_id = 3
|
||||||
max_id = config.vocab_size
|
max_id = config.vocab_size
|
||||||
@@ -893,7 +818,6 @@ class GenerationTesterMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -919,7 +843,6 @@ class GenerationTesterMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -938,11 +861,6 @@ class GenerationTesterMixin:
|
|||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
min_id = 3
|
min_id = 3
|
||||||
@@ -959,7 +877,6 @@ class GenerationTesterMixin:
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
|
|||||||
@@ -414,10 +414,6 @@ class BioGptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
@unittest.skip(reason="The `input_embeds` when fed don't produce the same results.")
|
|
||||||
def test_beam_sample_generate(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class BioGptModelIntegrationTest(unittest.TestCase):
|
class BioGptModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -433,6 +433,10 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
@unittest.skip("The `input_embeds` when fed don't produce the same results.")
|
||||||
|
def test_beam_sample_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class MambaIntegrationTests(unittest.TestCase):
|
class MambaIntegrationTests(unittest.TestCase):
|
||||||
|
|||||||
@@ -283,6 +283,12 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="Mamba2 does not support generating with input embeddings (custom cache_position computation)"
|
||||||
|
)
|
||||||
|
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -293,15 +293,9 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
@staticmethod
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = {}
|
||||||
input_length,
|
return logits_processor_kwargs
|
||||||
forced_bos_token_id=None,
|
|
||||||
forced_eos_token_id=None,
|
|
||||||
):
|
|
||||||
process_kwargs = {}
|
|
||||||
warper_kwargs = {}
|
|
||||||
return process_kwargs, warper_kwargs
|
|
||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
@@ -1483,15 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
|
|
||||||
@staticmethod
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = {}
|
||||||
input_length,
|
return logits_processor_kwargs
|
||||||
forced_bos_token_id=None,
|
|
||||||
forced_eos_token_id=None,
|
|
||||||
):
|
|
||||||
process_kwargs = {}
|
|
||||||
warper_kwargs = {}
|
|
||||||
return process_kwargs, warper_kwargs
|
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
|
|||||||
@@ -296,15 +296,9 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
return config, input_ids, attention_mask
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
@staticmethod
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = {}
|
||||||
input_length,
|
return logits_processor_kwargs
|
||||||
forced_bos_token_id=None,
|
|
||||||
forced_eos_token_id=None,
|
|
||||||
):
|
|
||||||
process_kwargs = {}
|
|
||||||
warper_kwargs = {}
|
|
||||||
return process_kwargs, warper_kwargs
|
|
||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
@@ -1467,15 +1461,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
return output_generate
|
return output_generate
|
||||||
|
|
||||||
@staticmethod
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
logits_processor_kwargs = {}
|
||||||
input_length,
|
return logits_processor_kwargs
|
||||||
forced_bos_token_id=None,
|
|
||||||
forced_eos_token_id=None,
|
|
||||||
):
|
|
||||||
process_kwargs = {}
|
|
||||||
warper_kwargs = {}
|
|
||||||
return process_kwargs, warper_kwargs
|
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
|
|||||||
@@ -413,6 +413,10 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="RecurrentGemma does not support generating with input embeddings (missing position_ids)")
|
||||||
|
def test_inputs_embeds_matches_input_ids_with_generate(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@slow
|
@slow
|
||||||
|
|||||||
@@ -68,14 +68,7 @@ if is_torch_available():
|
|||||||
set_seed,
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
BeamSampleDecoderOnlyOutput,
|
|
||||||
BeamSampleEncoderDecoderOutput,
|
|
||||||
BeamSearchDecoderOnlyOutput,
|
|
||||||
BeamSearchEncoderDecoderOutput,
|
|
||||||
GenerateBeamDecoderOnlyOutput,
|
|
||||||
GenerateBeamEncoderDecoderOutput,
|
|
||||||
GenerateEncoderDecoderOutput,
|
GenerateEncoderDecoderOutput,
|
||||||
PhrasalConstraint,
|
|
||||||
)
|
)
|
||||||
from transformers.generation.logits_process import LogitsProcessor
|
from transformers.generation.logits_process import LogitsProcessor
|
||||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
|
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
|
||||||
@@ -419,6 +412,30 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||||
|
# Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
|
||||||
|
logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
|
||||||
|
logits_processor_kwargs["temperature"] = 0.0
|
||||||
|
return logits_processor_kwargs
|
||||||
|
|
||||||
|
def _get_beam_kwargs(self, num_return_sequences=1):
|
||||||
|
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
|
||||||
|
beam_kwargs = super()._get_beam_kwargs(num_return_sequences=num_return_sequences)
|
||||||
|
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||||
|
return beam_kwargs
|
||||||
|
|
||||||
|
def _get_diverse_beam_kwargs(self, num_return_sequences=1):
|
||||||
|
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
|
||||||
|
beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
|
||||||
|
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||||
|
return beam_kwargs
|
||||||
|
|
||||||
|
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
|
||||||
|
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
|
||||||
|
beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
|
||||||
|
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
||||||
|
return beam_kwargs
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = WhisperModelTester(self)
|
self.model_tester = WhisperModelTester(self)
|
||||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||||
@@ -1551,241 +1568,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_longform_generate_multi_batch_cond_prev(self):
|
def test_longform_generate_multi_batch_cond_prev(self):
|
||||||
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
|
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
|
||||||
|
|
||||||
def test_beam_sample_generate_dict_output(self):
|
|
||||||
# We overwrite test_beam_sample_generate_dict_output in test_utils as
|
|
||||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = WhisperForConditionalGeneration(config).to(torch_device).eval()
|
|
||||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
|
||||||
|
|
||||||
# With Whisper, we can only perform a beam search if the temperature is set to 0.
|
|
||||||
logits_warper_kwargs["temperature"] = 0
|
|
||||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
|
||||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
|
||||||
|
|
||||||
output_generate = self._beam_sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
|
||||||
else:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
|
||||||
|
|
||||||
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
|
|
||||||
|
|
||||||
def test_beam_search_generate_dict_output(self):
|
|
||||||
# We overwrite test_beam_search_generate_dict_output in test_utils as
|
|
||||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
|
||||||
|
|
||||||
# With Whisper, we can only perform a beam search if the temperature is set to 0.
|
|
||||||
logits_process_kwargs["temperature"] = 0
|
|
||||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
|
||||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
|
||||||
|
|
||||||
output_generate = self._beam_search_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
|
||||||
else:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
|
||||||
|
|
||||||
self._check_outputs(
|
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
|
||||||
# We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
|
|
||||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
# enable cache
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
|
||||||
self.skipTest("This model doesn't support caching")
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
|
||||||
|
|
||||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
|
||||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
|
||||||
|
|
||||||
config.use_cache = True
|
|
||||||
config.is_decoder = True
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
output_generate = self._beam_search_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
|
||||||
else:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
|
||||||
self._check_outputs(
|
|
||||||
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_group_beam_search_generate_dict_output(self):
|
|
||||||
# We overwrite test_group_beam_search_generate_dict_output in test_utils as
|
|
||||||
# we can only perform beam search if the temperature is set to 0 in Whisper.
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
|
||||||
|
|
||||||
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
|
|
||||||
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
|
|
||||||
|
|
||||||
output_generate = self._group_beam_search_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
|
||||||
else:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
|
||||||
|
|
||||||
self._check_outputs(
|
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_constrained_beam_search_generate_dict_output(self):
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
config.forced_bos_token_id,
|
|
||||||
config.forced_eos_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample constraints
|
|
||||||
min_id = 3
|
|
||||||
max_id = model.config.vocab_size
|
|
||||||
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
|
|
||||||
constraints = [
|
|
||||||
PhrasalConstraint(force_tokens),
|
|
||||||
]
|
|
||||||
|
|
||||||
beam_kwargs = self._get_constrained_beam_kwargs()
|
|
||||||
output_generate = self._constrained_beam_search_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
constraints=constraints,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_logits=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
|
||||||
else:
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
|
||||||
# Retrocompatibility check
|
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
|
||||||
|
|
||||||
self._check_outputs(
|
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
|
|
||||||
)
|
|
||||||
|
|
||||||
@is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
|
@is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
|
||||||
def test_custom_4d_attention_mask(self):
|
def test_custom_4d_attention_mask(self):
|
||||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ OBJECTS_TO_IGNORE = [
|
|||||||
# Deprecated
|
# Deprecated
|
||||||
"InputExample",
|
"InputExample",
|
||||||
"InputFeatures",
|
"InputFeatures",
|
||||||
|
"LogitsWarper",
|
||||||
# Signature is *args/**kwargs
|
# Signature is *args/**kwargs
|
||||||
"TFSequenceSummary",
|
"TFSequenceSummary",
|
||||||
"TFBertTokenizer",
|
"TFBertTokenizer",
|
||||||
|
|||||||
@@ -932,6 +932,7 @@ DEPRECATED_OBJECTS = [
|
|||||||
"LineByLineTextDataset",
|
"LineByLineTextDataset",
|
||||||
"LineByLineWithRefDataset",
|
"LineByLineWithRefDataset",
|
||||||
"LineByLineWithSOPTextDataset",
|
"LineByLineWithSOPTextDataset",
|
||||||
|
"LogitsWarper",
|
||||||
"NerPipeline",
|
"NerPipeline",
|
||||||
"PretrainedBartModel",
|
"PretrainedBartModel",
|
||||||
"PretrainedFSMTModel",
|
"PretrainedFSMTModel",
|
||||||
|
|||||||
Reference in New Issue
Block a user