Generate: inner decoding methods are no longer public (#29437)

This commit is contained in:
Joao Gante
2024-03-05 10:27:36 +00:00
committed by GitHub
parent 4d892b7297
commit 87a0783dde
11 changed files with 117 additions and 104 deletions

View File

@@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are going to the same party. It is a small party, in a small'] ['Alice and Bob are going to the same party. It is a small party, in a small']
``` ```
Alternativelly, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).

View File

@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
# Utilities for Generation # Utilities for Generation
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`], This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], and
[`~generation.GenerationMixin.constrained_beam_search`].
Most of those are only useful if you are studying the code of the generate methods in the library.
## Generate Outputs ## Generate Outputs

View File

@@ -43,13 +43,6 @@ like token streaming.
[[autodoc]] generation.GenerationMixin [[autodoc]] generation.GenerationMixin
- generate - generate
- compute_transition_scores - compute_transition_scores
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
## TFGenerationMixin ## TFGenerationMixin

View File

@@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.
# 発電用ユーティリティ # 発電用ユーティリティ
このページには、[`~generation.GenerationMixin.generate`] で使用されるすべてのユーティリティ関数がリストされています。 このページには、[`~generation.GenerationMixin.generate`] で使用されるすべてのユーティリティ関数がリストされています。
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`]、および
[`~generation.GenerationMixin.constrained_beam_search`]。
これらのほとんどは、ライブラリ内の生成メソッドのコードを学習する場合にのみ役に立ちます。
## 出力を生成する ## 出力を生成する

View File

@@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.
[[autodoc]] generation.GenerationMixin [[autodoc]] generation.GenerationMixin
- generate - generate
- compute_transition_scores - compute_transition_scores
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
## TFGenerationMixin ## TFGenerationMixin

View File

@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
# 用于生成的工具 # 用于生成的工具
此页面列出了所有由 [`~generation.GenerationMixin.generate`], 此页面列出了所有由 [`~generation.GenerationMixin.generate`]
[`~generation.GenerationMixin.greedy_search`],
[`~generation.GenerationMixin.contrastive_search`],
[`~generation.GenerationMixin.sample`],
[`~generation.GenerationMixin.beam_search`],
[`~generation.GenerationMixin.beam_sample`],
[`~generation.GenerationMixin.group_beam_search`], 和
[`~generation.GenerationMixin.constrained_beam_search`]使用的实用函数。
其中大多数仅在您研究库中生成方法的代码时才有用。
## 生成输出 ## 生成输出

View File

@@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.
[[autodoc]] generation.GenerationMixin [[autodoc]] generation.GenerationMixin
- generate - generate
- compute_transition_scores - compute_transition_scores
- greedy_search
- sample
- beam_search
- beam_sample
- contrastive_search
- group_beam_search
- constrained_beam_search
## TFGenerationMixin ## TFGenerationMixin

View File

@@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models: for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
`do_sample=False` `do_sample=False`
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.` - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
and `top_k>1` and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
`do_sample=True` `do_sample=True`
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
`do_sample=False` `do_sample=False`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
`num_beams>1` and `do_sample=True` `num_beams>1` and `do_sample=True`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
`num_beams>1` and `num_beam_groups>1` `num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None` `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` is passed to `.generate()` `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).

View File

@@ -347,20 +347,22 @@ class GenerationMixin:
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
The class exposes [`~generation.GenerationMixin.generate`], which can be used for: The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
`do_sample=False` `do_sample=False`
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
`top_k>1` `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
`do_sample=True` `do_sample=True`
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
`do_sample=False` `do_sample=False`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1` - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
and `do_sample=True` and `do_sample=True`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1` - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
and `num_beam_groups>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None` `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies). learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -1547,7 +1549,7 @@ class GenerationMixin:
) )
if generation_mode == GenerationMode.GREEDY_SEARCH: if generation_mode == GenerationMode.GREEDY_SEARCH:
# 11. run greedy search # 11. run greedy search
result = self.greedy_search( result = self._greedy_search(
input_ids, input_ids,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria, stopping_criteria=prepared_stopping_criteria,
@@ -1565,7 +1567,7 @@ class GenerationMixin:
if not model_kwargs["use_cache"]: if not model_kwargs["use_cache"]:
raise ValueError("Contrastive search requires `use_cache=True`") raise ValueError("Contrastive search requires `use_cache=True`")
result = self.contrastive_search( result = self._contrastive_search(
input_ids, input_ids,
top_k=generation_config.top_k, top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha, penalty_alpha=generation_config.penalty_alpha,
@@ -1595,7 +1597,7 @@ class GenerationMixin:
) )
# 13. run sample # 13. run sample
result = self.sample( result = self._sample(
input_ids, input_ids,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,
@@ -1629,7 +1631,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
# 13. run beam search # 13. run beam search
result = self.beam_search( result = self._beam_search(
input_ids, input_ids,
beam_scorer, beam_scorer,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
@@ -1668,7 +1670,7 @@ class GenerationMixin:
) )
# 14. run beam sample # 14. run beam sample
result = self.beam_sample( result = self._beam_sample(
input_ids, input_ids,
beam_scorer, beam_scorer,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
@@ -1703,7 +1705,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
# 13. run beam search # 13. run beam search
result = self.group_beam_search( result = self._group_beam_search(
input_ids, input_ids,
beam_scorer, beam_scorer,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
@@ -1777,7 +1779,7 @@ class GenerationMixin:
**model_kwargs, **model_kwargs,
) )
# 13. run beam search # 13. run beam search
result = self.constrained_beam_search( result = self._constrained_beam_search(
input_ids, input_ids,
constrained_beam_scorer=constrained_beam_scorer, constrained_beam_scorer=constrained_beam_scorer,
logits_processor=prepared_logits_processor, logits_processor=prepared_logits_processor,
@@ -1801,8 +1803,15 @@ class GenerationMixin:
return result return result
def contrastive_search(self, *args, **kwargs):
logger.warning_once(
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._contrastive_search(*args, **kwargs)
@torch.no_grad() @torch.no_grad()
def contrastive_search( def _contrastive_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
top_k: Optional[int] = 1, top_k: Optional[int] = 1,
@@ -1828,7 +1837,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -1902,7 +1911,7 @@ class GenerationMixin:
>>> input_prompt = "DeepMind Company is" >>> input_prompt = "DeepMind Company is"
>>> input_ids = tokenizer(input_prompt, return_tensors="pt") >>> input_ids = tokenizer(input_prompt, return_tensors="pt")
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)]) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
>>> outputs = model.contrastive_search( >>> outputs = model._contrastive_search(
... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria ... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
... ) ... )
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -2243,7 +2252,14 @@ class GenerationMixin:
else: else:
return input_ids return input_ids
def greedy_search( def greedy_search(self, *args, **kwargs):
logger.warning_once(
"Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._greedy_search(*args, **kwargs)
def _greedy_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
@@ -2266,7 +2282,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -2348,7 +2364,7 @@ class GenerationMixin:
... ) ... )
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> outputs = model.greedy_search( >>> outputs = model._greedy_search(
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
... ) ... )
@@ -2514,7 +2530,14 @@ class GenerationMixin:
else: else:
return input_ids return input_ids
def sample( def sample(self, *args, **kwargs):
logger.warning_once(
"Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._sample(*args, **kwargs)
def _sample(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
@@ -2538,7 +2561,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
For an overview of generation strategies and code examples, check the [following For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -2635,7 +2658,7 @@ class GenerationMixin:
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
>>> outputs = model.sample( >>> outputs = model._sample(
... input_ids, ... input_ids,
... logits_processor=logits_processor, ... logits_processor=logits_processor,
... logits_warper=logits_warper, ... logits_warper=logits_warper,
@@ -2832,7 +2855,14 @@ class GenerationMixin:
past_key_values.reorder_cache(beam_idx) past_key_values.reorder_cache(beam_idx)
return past_key_values return past_key_values
def beam_search( def beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._beam_search(*args, **kwargs)
def _beam_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
beam_scorer: BeamScorer, beam_scorer: BeamScorer,
@@ -2856,7 +2886,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate() In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -2958,7 +2988,7 @@ class GenerationMixin:
... ] ... ]
... ) ... )
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Wie alt bist du?'] ['Wie alt bist du?']
@@ -3214,7 +3244,14 @@ class GenerationMixin:
else: else:
return sequence_outputs["sequences"] return sequence_outputs["sequences"]
def beam_sample( def beam_sample(self, *args, **kwargs):
logger.warning_once(
"Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._beam_sample(*args, **kwargs)
def _beam_sample(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
beam_scorer: BeamScorer, beam_scorer: BeamScorer,
@@ -3238,7 +3275,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate() In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -3346,7 +3383,7 @@ class GenerationMixin:
... ] ... ]
... ) ... )
>>> outputs = model.beam_sample( >>> outputs = model._beam_sample(
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs ... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
... ) ... )
@@ -3561,7 +3598,14 @@ class GenerationMixin:
else: else:
return sequence_outputs["sequences"] return sequence_outputs["sequences"]
def group_beam_search( def group_beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._group_beam_search(*args, **kwargs)
def _group_beam_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
beam_scorer: BeamScorer, beam_scorer: BeamScorer,
@@ -3584,7 +3628,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -3686,7 +3730,7 @@ class GenerationMixin:
... ] ... ]
... ) ... )
>>> outputs = model.group_beam_search( >>> outputs = model._group_beam_search(
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs ... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
... ) ... )
@@ -3958,7 +4002,14 @@ class GenerationMixin:
else: else:
return sequence_outputs["sequences"] return sequence_outputs["sequences"]
def constrained_beam_search( def constrained_beam_search(self, *args, **kwargs):
logger.warning_once(
"Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._constrained_beam_search(*args, **kwargs)
def _constrained_beam_search(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
constrained_beam_scorer: ConstrainedBeamSearchScorer, constrained_beam_scorer: ConstrainedBeamSearchScorer,
@@ -3981,7 +4032,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -4088,7 +4139,7 @@ class GenerationMixin:
... ] ... ]
... ) ... )
>>> outputs = model.constrained_beam_search( >>> outputs = model._constrained_beam_search(
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs ... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
... ) ... )
@@ -4311,7 +4362,14 @@ class GenerationMixin:
else: else:
return sequence_outputs["sequences"] return sequence_outputs["sequences"]
def assisted_decoding( def assisted_decoding(self, *args, **kwargs):
logger.warning_once(
"Calling `_assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
"custom generation loop instead.",
)
return self._assisted_decoding(*args, **kwargs)
def _assisted_decoding(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
candidate_generator: Optional["CandidateGenerator"] = None, candidate_generator: Optional["CandidateGenerator"] = None,
@@ -4338,7 +4396,7 @@ class GenerationMixin:
<Tip warning={true}> <Tip warning={true}>
In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
generate() instead. For an overview of generation strategies and code examples, check the [following generate() instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies). guide](../generation_strategies).
@@ -4429,7 +4487,7 @@ class GenerationMixin:
... logits_processor=logits_processor, ... logits_processor=logits_processor,
... model_kwargs={}, ... model_kwargs={},
... ) ... )
>>> outputs = model.assisted_decoding( >>> outputs = model._assisted_decoding(
... input_ids, ... input_ids,
... candidate_generator=candidate_generator, ... candidate_generator=candidate_generator,
... logits_processor=logits_processor, ... logits_processor=logits_processor,

View File

@@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self.greedy_search( outputs = self._greedy_search(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
@@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
) )
# 12. run sample # 12. run sample
outputs = self.sample( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,
@@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
) )
# 11. run greedy search # 11. run greedy search
outputs = self.greedy_search( outputs = self._greedy_search(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
@@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
) )
# 12. run sample # 12. run sample
outputs = self.sample( outputs = self._sample(
input_ids, input_ids,
logits_processor=logits_processor, logits_processor=logits_processor,
logits_warper=logits_warper, logits_warper=logits_warper,

View File

@@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search." " greedy search."
) )
return self.greedy_search( return self._greedy_search(
input_ids, input_ids,
logits_processor=pre_processor, logits_processor=pre_processor,
max_length=generation_config.max_length, max_length=generation_config.max_length,
@@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
num_beam_hyps_to_keep=generation_config.num_return_sequences, num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length, max_length=generation_config.max_length,
) )
return self.beam_search( return self._beam_search(
input_ids, input_ids,
beam_scorer, beam_scorer,
logits_processor=pre_processor, logits_processor=pre_processor,