Docs: add missing StoppingCriteria autodocs (#30617)

* add missing docstrings to docs

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Joao Gante
2024-05-02 15:20:04 +01:00
committed by GitHub
parent aa55ff44a2
commit 66abe13951
4 changed files with 25 additions and 1 deletions

View File

@@ -310,6 +310,12 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
[[autodoc]] MaxTimeCriteria [[autodoc]] MaxTimeCriteria
- __call__ - __call__
[[autodoc]] StopStringCriteria
- __call__
[[autodoc]] EosTokenCriteria
- __call__
## Constraints ## Constraints
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations. A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations.

View File

@@ -1419,6 +1419,7 @@ else:
"DisjunctiveConstraint", "DisjunctiveConstraint",
"EncoderNoRepeatNGramLogitsProcessor", "EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor", "EncoderRepetitionPenaltyLogitsProcessor",
"EosTokenCriteria",
"EpsilonLogitsWarper", "EpsilonLogitsWarper",
"EtaLogitsWarper", "EtaLogitsWarper",
"ExponentialDecayLengthPenalty", "ExponentialDecayLengthPenalty",
@@ -1444,6 +1445,7 @@ else:
"SequenceBiasLogitsProcessor", "SequenceBiasLogitsProcessor",
"StoppingCriteria", "StoppingCriteria",
"StoppingCriteriaList", "StoppingCriteriaList",
"StopStringCriteria",
"SuppressTokensAtBeginLogitsProcessor", "SuppressTokensAtBeginLogitsProcessor",
"SuppressTokensLogitsProcessor", "SuppressTokensLogitsProcessor",
"TemperatureLogitsWarper", "TemperatureLogitsWarper",
@@ -6372,6 +6374,7 @@ if TYPE_CHECKING:
DisjunctiveConstraint, DisjunctiveConstraint,
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor,
EosTokenCriteria,
EpsilonLogitsWarper, EpsilonLogitsWarper,
EtaLogitsWarper, EtaLogitsWarper,
ExponentialDecayLengthPenalty, ExponentialDecayLengthPenalty,
@@ -6397,6 +6400,7 @@ if TYPE_CHECKING:
SequenceBiasLogitsProcessor, SequenceBiasLogitsProcessor,
StoppingCriteria, StoppingCriteria,
StoppingCriteriaList, StoppingCriteriaList,
StopStringCriteria,
SuppressTokensAtBeginLogitsProcessor, SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor, SuppressTokensLogitsProcessor,
TemperatureLogitsWarper, TemperatureLogitsWarper,

View File

@@ -98,7 +98,7 @@ class MaxNewTokensCriteria(StoppingCriteria):
def __init__(self, start_length: int, max_new_tokens: int): def __init__(self, start_length: int, max_new_tokens: int):
warnings.warn( warnings.warn(
"The class `MaxNewTokensCriteria` is deprecated. " "The class `MaxNewTokensCriteria` is deprecated and will be removed in v4.43. "
f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` " f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
"with `max_length = start_length + max_new_tokens` instead.", "with `max_length = start_length + max_new_tokens` instead.",
FutureWarning, FutureWarning,

View File

@@ -177,6 +177,13 @@ class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class EosTokenCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EpsilonLogitsWarper(metaclass=DummyObject): class EpsilonLogitsWarper(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
@@ -352,6 +359,13 @@ class StoppingCriteriaList(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class StopStringCriteria(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]