Generate: inner decoding methods are no longer public (#29437)
@@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
|
|||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
['Alice and Bob are going to the same party. It is a small party, in a small']
|
['Alice and Bob are going to the same party. It is a small party, in a small']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Alternatively, you can also set `prompt_lookup_num_tokens` to trigger n-gram-based assisted decoding, as opposed
|
||||||
|
to model-based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
|
||||||
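The n-gram idea mentioned above can be illustrated with a minimal sketch. This is not the library's implementation: the function name `prompt_lookup_candidates` and its parameters `ngram_size` and `num_tokens` are hypothetical, chosen only for illustration. The sketch looks for an earlier occurrence of the prompt's last n-gram and proposes the tokens that followed it as draft candidates.

```python
from typing import List


def prompt_lookup_candidates(
    input_ids: List[int], ngram_size: int = 2, num_tokens: int = 3
) -> List[int]:
    """Toy n-gram lookup (hypothetical helper, not the transformers API).

    Proposes up to `num_tokens` tokens that followed the most recent earlier
    occurrence of the last `ngram_size` tokens of the prompt.
    """
    if ngram_size <= 0 or len(input_ids) <= ngram_size:
        return []
    tail = input_ids[-ngram_size:]
    # Scan earlier start positions from most recent to oldest,
    # skipping the position of the tail itself.
    for start in range(len(input_ids) - ngram_size - 1, -1, -1):
        if input_ids[start : start + ngram_size] == tail:
            follow = input_ids[start + ngram_size : start + ngram_size + num_tokens]
            if follow:
                return follow
    return []
```

In the library itself, the equivalent behaviour is requested by passing `prompt_lookup_num_tokens` to `generate`, as the documentation text above describes; the candidates are then verified by the main model rather than accepted blindly.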
|
|||||||
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
# Utilities for Generation
|
# Utilities for Generation
|
||||||
|
|
||||||
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
|
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
|
||||||
[`~generation.GenerationMixin.greedy_search`],
|
|
||||||
[`~generation.GenerationMixin.contrastive_search`],
|
|
||||||
[`~generation.GenerationMixin.sample`],
|
|
||||||
[`~generation.GenerationMixin.beam_search`],
|
|
||||||
[`~generation.GenerationMixin.beam_sample`],
|
|
||||||
[`~generation.GenerationMixin.group_beam_search`], and
|
|
||||||
[`~generation.GenerationMixin.constrained_beam_search`].
|
|
||||||
|
|
||||||
Most of those are only useful if you are studying the code of the generate methods in the library.
|
|
||||||
|
|
||||||
## Generate Outputs
|
## Generate Outputs
|
||||||
|
|
||||||
|
|||||||
@@ -43,13 +43,6 @@ like token streaming.
|
|||||||
[[autodoc]] generation.GenerationMixin
|
[[autodoc]] generation.GenerationMixin
|
||||||
- generate
|
- generate
|
||||||
- compute_transition_scores
|
- compute_transition_scores
|
||||||
- greedy_search
|
|
||||||
- sample
|
|
||||||
- beam_search
|
|
||||||
- beam_sample
|
|
||||||
- contrastive_search
|
|
||||||
- group_beam_search
|
|
||||||
- constrained_beam_search
|
|
||||||
|
|
||||||
## TFGenerationMixin
|
## TFGenerationMixin
|
||||||
|
|
||||||
|
|||||||
@@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.
|
|||||||
# Utilities for Generation
|
# Utilities for Generation
|
||||||
|
|
||||||
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
|
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
|
||||||
[`~generation.GenerationMixin.greedy_search`],
|
|
||||||
[`~generation.GenerationMixin.contrastive_search`],
|
|
||||||
[`~generation.GenerationMixin.sample`],
|
|
||||||
[`~generation.GenerationMixin.beam_search`],
|
|
||||||
[`~generation.GenerationMixin.beam_sample`],
|
|
||||||
[`~generation.GenerationMixin.group_beam_search`], and
|
|
||||||
[`~generation.GenerationMixin.constrained_beam_search`].
|
|
||||||
|
|
||||||
Most of those are only useful if you are studying the code of the generate methods in the library.
|
|
||||||
|
|
||||||
## Generate Outputs
|
## Generate Outputs
|
||||||
|
|
||||||
|
|||||||
@@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.
|
|||||||
[[autodoc]] generation.GenerationMixin
|
[[autodoc]] generation.GenerationMixin
|
||||||
- generate
|
- generate
|
||||||
- compute_transition_scores
|
- compute_transition_scores
|
||||||
- greedy_search
|
|
||||||
- sample
|
|
||||||
- beam_search
|
|
||||||
- beam_sample
|
|
||||||
- contrastive_search
|
|
||||||
- group_beam_search
|
|
||||||
- constrained_beam_search
|
|
||||||
|
|
||||||
## TFGenerationMixin
|
## TFGenerationMixin
|
||||||
|
|
||||||
|
|||||||
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
# Utilities for Generation
|
# Utilities for Generation
|
||||||
|
|
||||||
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
|
This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
|
||||||
[`~generation.GenerationMixin.greedy_search`],
|
|
||||||
[`~generation.GenerationMixin.contrastive_search`],
|
|
||||||
[`~generation.GenerationMixin.sample`],
|
|
||||||
[`~generation.GenerationMixin.beam_search`],
|
|
||||||
[`~generation.GenerationMixin.beam_sample`],
|
|
||||||
[`~generation.GenerationMixin.group_beam_search`], and
|
|
||||||
[`~generation.GenerationMixin.constrained_beam_search`].
|
|
||||||
|
|
||||||
Most of those are only useful if you are studying the code of the generate methods in the library.
|
|
||||||
|
|
||||||
## Generate Outputs
|
## Generate Outputs
|
||||||
|
|
||||||
|
|||||||
@@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.
|
|||||||
[[autodoc]] generation.GenerationMixin
|
[[autodoc]] generation.GenerationMixin
|
||||||
- generate
|
- generate
|
||||||
- compute_transition_scores
|
- compute_transition_scores
|
||||||
- greedy_search
|
|
||||||
- sample
|
|
||||||
- beam_search
|
|
||||||
- beam_sample
|
|
||||||
- contrastive_search
|
|
||||||
- group_beam_search
|
|
||||||
- constrained_beam_search
|
|
||||||
|
|
||||||
## TFGenerationMixin
|
## TFGenerationMixin
|
||||||
|
|
||||||
|
|||||||
@@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
|
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
|
||||||
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
|
||||||
|
|
||||||
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
|
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
|
||||||
`do_sample=False`
|
`do_sample=False`
|
||||||
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
|
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
|
||||||
and `top_k>1`
|
and `top_k>1`
|
||||||
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
|
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
|
||||||
`do_sample=True`
|
`do_sample=True`
|
||||||
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
|
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
|
||||||
`do_sample=False`
|
`do_sample=False`
|
||||||
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
|
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
|
||||||
`num_beams>1` and `do_sample=True`
|
`num_beams>1` and `do_sample=True`
|
||||||
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
|
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
|
||||||
`num_beams>1` and `num_beam_groups>1`
|
`num_beams>1` and `num_beam_groups>1`
|
||||||
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
|
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
|
||||||
`constraints!=None` or `force_words_ids!=None`
|
`constraints!=None` or `force_words_ids!=None`
|
||||||
- *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
|
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
|
||||||
`assistant_model` is passed to `.generate()`
|
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
|
||||||
|
|
||||||
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
|
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
|
||||||
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||||
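The flag-to-method mapping listed in the docstring above can be sketched as a small pure-Python function. This is a simplified illustration under stated assumptions, not the library's dispatch code: the function name `select_decoding_mode`, the boolean parameters `has_constraints` and `has_assistant`, and in particular the precedence applied when several conditions hold at once are all hypothetical.

```python
from typing import Optional


def select_decoding_mode(
    num_beams: int = 1,
    do_sample: bool = False,
    num_beam_groups: int = 1,
    penalty_alpha: Optional[float] = None,
    top_k: Optional[int] = None,
    has_constraints: bool = False,  # stands in for `constraints` / `force_words_ids`
    has_assistant: bool = False,  # stands in for `assistant_model` / `prompt_lookup_num_tokens`
) -> str:
    """Return the name of the decoding strategy implied by the flags.

    The ordering of the checks below is an assumption of this sketch.
    """
    if has_constraints:
        return "constrained_beam_search"
    if num_beams > 1:
        if num_beam_groups > 1:
            return "group_beam_search"
        return "beam_sample" if do_sample else "beam_search"
    # From here on, num_beams == 1.
    if has_assistant:
        return "assisted_decoding"
    if do_sample:
        return "sample"
    if penalty_alpha is not None and penalty_alpha > 0 and top_k is not None and top_k > 1:
        return "contrastive_search"
    return "greedy_search"
```

The point of the sketch is that callers only ever set flags; which inner method runs is a consequence of those flags, which is why the inner methods can be made private without changing how `generate` is used.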
|
|||||||
@@ -347,20 +347,22 @@ class GenerationMixin:
|
|||||||
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
|
||||||
|
|
||||||
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
|
The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
|
||||||
- *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
|
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
|
||||||
`do_sample=False`
|
`do_sample=False`
|
||||||
- *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
|
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
|
||||||
`top_k>1`
|
`top_k>1`
|
||||||
- *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
|
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
|
||||||
`do_sample=True`
|
`do_sample=True`
|
||||||
- *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
|
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
|
||||||
`do_sample=False`
|
`do_sample=False`
|
||||||
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
|
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
|
||||||
and `do_sample=True`
|
and `do_sample=True`
|
||||||
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
|
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
|
||||||
and `num_beam_groups>1`
|
and `num_beam_groups>1`
|
||||||
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
|
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
|
||||||
`constraints!=None` or `force_words_ids!=None`
|
`constraints!=None` or `force_words_ids!=None`
|
||||||
|
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
|
||||||
|
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
|
||||||
|
|
||||||
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
|
You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
|
||||||
learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
|
||||||
@@ -1547,7 +1549,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
if generation_mode == GenerationMode.GREEDY_SEARCH:
|
if generation_mode == GenerationMode.GREEDY_SEARCH:
|
||||||
# 11. run greedy search
|
# 11. run greedy search
|
||||||
result = self.greedy_search(
|
result = self._greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
stopping_criteria=prepared_stopping_criteria,
|
stopping_criteria=prepared_stopping_criteria,
|
||||||
@@ -1565,7 +1567,7 @@ class GenerationMixin:
|
|||||||
if not model_kwargs["use_cache"]:
|
if not model_kwargs["use_cache"]:
|
||||||
raise ValueError("Contrastive search requires `use_cache=True`")
|
raise ValueError("Contrastive search requires `use_cache=True`")
|
||||||
|
|
||||||
result = self.contrastive_search(
|
result = self._contrastive_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
top_k=generation_config.top_k,
|
top_k=generation_config.top_k,
|
||||||
penalty_alpha=generation_config.penalty_alpha,
|
penalty_alpha=generation_config.penalty_alpha,
|
||||||
@@ -1595,7 +1597,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 13. run sample
|
# 13. run sample
|
||||||
result = self.sample(
|
result = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
@@ -1629,7 +1631,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# 13. run beam search
|
# 13. run beam search
|
||||||
result = self.beam_search(
|
result = self._beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
@@ -1668,7 +1670,7 @@ class GenerationMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 14. run beam sample
|
# 14. run beam sample
|
||||||
result = self.beam_sample(
|
result = self._beam_sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
@@ -1703,7 +1705,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# 13. run beam search
|
# 13. run beam search
|
||||||
result = self.group_beam_search(
|
result = self._group_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
@@ -1777,7 +1779,7 @@ class GenerationMixin:
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# 13. run beam search
|
# 13. run beam search
|
||||||
result = self.constrained_beam_search(
|
result = self._constrained_beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
constrained_beam_scorer=constrained_beam_scorer,
|
constrained_beam_scorer=constrained_beam_scorer,
|
||||||
logits_processor=prepared_logits_processor,
|
logits_processor=prepared_logits_processor,
|
||||||
@@ -1801,8 +1803,15 @@ class GenerationMixin:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def contrastive_search(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._contrastive_search(*args, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def contrastive_search(
|
def _contrastive_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
top_k: Optional[int] = 1,
|
top_k: Optional[int] = 1,
|
||||||
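The hunk above keeps the old public name as a thin wrapper that warns and forwards to the new private method. A minimal, self-contained sketch of that pattern follows. It is an illustration only: `ToyGenerator` and its scoring logic are hypothetical, and the sketch uses the standard library's `warnings.warn` in place of the `logger.warning_once` call used in the commit.

```python
import warnings


class ToyGenerator:
    """Hypothetical stand-in for a class that is making a method private."""

    def _greedy_search(self, scores):
        # Private implementation: return the index of the highest score.
        return max(range(len(scores)), key=scores.__getitem__)

    def greedy_search(self, *args, **kwargs):
        # Public shim kept for a deprecation period; it forwards all
        # arguments unchanged to the private implementation.
        warnings.warn(
            "Calling `greedy_search` directly is deprecated. "
            "Use the public entry point instead.",
            FutureWarning,
            stacklevel=2,
        )
        return self._greedy_search(*args, **kwargs)
```

Accepting `*args, **kwargs` in the shim means the wrapper does not need to repeat the full signature, at the cost of losing it from the public method's introspection and documentation.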
@@ -1828,7 +1837,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
|
In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
|
||||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -1902,7 +1911,7 @@ class GenerationMixin:
|
|||||||
>>> input_prompt = "DeepMind Company is"
|
>>> input_prompt = "DeepMind Company is"
|
||||||
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
|
>>> input_ids = tokenizer(input_prompt, return_tensors="pt")
|
||||||
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
|
||||||
>>> outputs = model.contrastive_search(
|
>>> outputs = model._contrastive_search(
|
||||||
... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
|
... **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
|
||||||
... )
|
... )
|
||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
@@ -2243,7 +2252,14 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._greedy_search(*args, **kwargs)
|
||||||
|
|
||||||
|
def _greedy_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
@@ -2266,7 +2282,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
|
In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
|
||||||
instead. For an overview of generation strategies and code examples, check the [following
|
instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -2348,7 +2364,7 @@ class GenerationMixin:
|
|||||||
... )
|
... )
|
||||||
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
||||||
|
|
||||||
>>> outputs = model.greedy_search(
|
>>> outputs = model._greedy_search(
|
||||||
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
|
... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
|
||||||
... )
|
... )
|
||||||
|
|
||||||
@@ -2514,7 +2530,14 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def sample(
|
def sample(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._sample(*args, **kwargs)
|
||||||
|
|
||||||
|
def _sample(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
@@ -2538,7 +2561,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
|
In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
|
||||||
For an overview of generation strategies and code examples, check the [following
|
For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -2635,7 +2658,7 @@ class GenerationMixin:
|
|||||||
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
>>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
|
||||||
|
|
||||||
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
|
>>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
|
||||||
>>> outputs = model.sample(
|
>>> outputs = model._sample(
|
||||||
... input_ids,
|
... input_ids,
|
||||||
... logits_processor=logits_processor,
|
... logits_processor=logits_processor,
|
||||||
... logits_warper=logits_warper,
|
... logits_warper=logits_warper,
|
||||||
@@ -2832,7 +2855,14 @@ class GenerationMixin:
|
|||||||
past_key_values.reorder_cache(beam_idx)
|
past_key_values.reorder_cache(beam_idx)
|
||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._beam_search(*args, **kwargs)
|
||||||
|
|
||||||
|
def _beam_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
beam_scorer: BeamScorer,
|
beam_scorer: BeamScorer,
|
||||||
@@ -2856,7 +2886,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
|
In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
|
||||||
instead. For an overview of generation strategies and code examples, check the [following
|
instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -2958,7 +2988,7 @@ class GenerationMixin:
|
|||||||
... ]
|
... ]
|
||||||
... )
|
... )
|
||||||
|
|
||||||
>>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
>>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
|
||||||
|
|
||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
['Wie alt bist du?']
|
['Wie alt bist du?']
|
||||||
@@ -3214,7 +3244,14 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
|
|
||||||
def beam_sample(
|
def beam_sample(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._beam_sample(*args, **kwargs)
|
||||||
|
|
||||||
|
def _beam_sample(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
beam_scorer: BeamScorer,
|
beam_scorer: BeamScorer,
|
||||||
@@ -3238,7 +3275,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
|
In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
|
||||||
instead. For an overview of generation strategies and code examples, check the [following
|
instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -3346,7 +3383,7 @@ class GenerationMixin:
|
|||||||
... ]
|
... ]
|
||||||
... )
|
... )
|
||||||
|
|
||||||
>>> outputs = model.beam_sample(
|
>>> outputs = model._beam_sample(
|
||||||
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
|
... input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
|
||||||
... )
|
... )
|
||||||
|
|
||||||
@@ -3561,7 +3598,14 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
|
|
||||||
def group_beam_search(
|
def group_beam_search(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._group_beam_search(*args, **kwargs)
|
||||||
|
|
||||||
|
def _group_beam_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
beam_scorer: BeamScorer,
|
beam_scorer: BeamScorer,
|
||||||
@@ -3584,7 +3628,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
|
In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
|
||||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -3686,7 +3730,7 @@ class GenerationMixin:
|
|||||||
... ]
|
... ]
|
||||||
... )
|
... )
|
||||||
|
|
||||||
>>> outputs = model.group_beam_search(
|
>>> outputs = model._group_beam_search(
|
||||||
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
|
... input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
|
||||||
... )
|
... )
|
||||||
|
|
||||||
@@ -3958,7 +4002,14 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
|
|
||||||
def constrained_beam_search(
|
def constrained_beam_search(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._constrained_beam_search(*args, **kwargs)
|
||||||
|
|
||||||
|
def _constrained_beam_search(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
constrained_beam_scorer: ConstrainedBeamSearchScorer,
|
constrained_beam_scorer: ConstrainedBeamSearchScorer,
|
||||||
@@ -3981,7 +4032,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
|
In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
|
||||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -4088,7 +4139,7 @@ class GenerationMixin:
|
|||||||
... ]
|
... ]
|
||||||
... )
|
... )
|
||||||
|
|
||||||
>>> outputs = model.constrained_beam_search(
|
>>> outputs = model._constrained_beam_search(
|
||||||
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
|
... input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
|
||||||
... )
|
... )
|
||||||
|
|
||||||
@@ -4311,7 +4362,14 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
return sequence_outputs["sequences"]
|
return sequence_outputs["sequences"]
|
||||||
|
|
||||||
def assisted_decoding(
|
def assisted_decoding(self, *args, **kwargs):
|
||||||
|
logger.warning_once(
|
||||||
|
"Calling `assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
|
||||||
|
"custom generation loop instead.",
|
||||||
|
)
|
||||||
|
return self._assisted_decoding(*args, **kwargs)
|
||||||
|
|
||||||
|
def _assisted_decoding(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
candidate_generator: Optional["CandidateGenerator"] = None,
|
candidate_generator: Optional["CandidateGenerator"] = None,
|
||||||
@@ -4338,7 +4396,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
<Tip warning={true}>
|
<Tip warning={true}>
|
||||||
|
|
||||||
In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
|
In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
|
||||||
generate() instead. For an overview of generation strategies and code examples, check the [following
|
generate() instead. For an overview of generation strategies and code examples, check the [following
|
||||||
guide](../generation_strategies).
|
guide](../generation_strategies).
|
||||||
|
|
||||||
@@ -4429,7 +4487,7 @@ class GenerationMixin:
|
|||||||
... logits_processor=logits_processor,
|
... logits_processor=logits_processor,
|
||||||
... model_kwargs={},
|
... model_kwargs={},
|
||||||
... )
|
... )
|
||||||
>>> outputs = model.assisted_decoding(
|
>>> outputs = model._assisted_decoding(
|
||||||
... input_ids,
|
... input_ids,
|
||||||
... candidate_generator=candidate_generator,
|
... candidate_generator=candidate_generator,
|
||||||
... logits_processor=logits_processor,
|
... logits_processor=logits_processor,
|
||||||
|
|||||||
@@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 11. run greedy search
|
# 11. run greedy search
|
||||||
outputs = self.greedy_search(
|
outputs = self._greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
@@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 12. run sample
|
||||||
outputs = self.sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
@@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 11. run greedy search
|
# 11. run greedy search
|
||||||
outputs = self.greedy_search(
|
outputs = self._greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
stopping_criteria=stopping_criteria,
|
stopping_criteria=stopping_criteria,
|
||||||
@@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 12. run sample
|
# 12. run sample
|
||||||
outputs = self.sample(
|
outputs = self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
|
|||||||
@@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
|
||||||
" greedy search."
|
" greedy search."
|
||||||
)
|
)
|
||||||
return self.greedy_search(
|
return self._greedy_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
logits_processor=pre_processor,
|
logits_processor=pre_processor,
|
||||||
max_length=generation_config.max_length,
|
max_length=generation_config.max_length,
|
||||||
@@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
||||||
max_length=generation_config.max_length,
|
max_length=generation_config.max_length,
|
||||||
)
|
)
|
||||||
return self.beam_search(
|
return self._beam_search(
|
||||||
input_ids,
|
input_ids,
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
logits_processor=pre_processor,
|
logits_processor=pre_processor,
|
||||||
|
|||||||