From 66abe1395157f8cb18830166625d886671eeb2fb Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 2 May 2024 15:20:04 +0100 Subject: [PATCH] Docs: add missing `StoppingCriteria` autodocs (#30617) * add missing docstrings to docs * Update src/transformers/generation/stopping_criteria.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/internal/generation_utils.md | 6 ++++++ src/transformers/__init__.py | 4 ++++ src/transformers/generation/stopping_criteria.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index e6872efe73..58c1d4478b 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -310,6 +310,12 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than [[autodoc]] MaxTimeCriteria - __call__ +[[autodoc]] StopStringCriteria + - __call__ + +[[autodoc]] EosTokenCriteria + - __call__ + ## Constraints A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations. diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 083c7f031a..53a087468e 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1419,6 +1419,7 @@ else: "DisjunctiveConstraint", "EncoderNoRepeatNGramLogitsProcessor", "EncoderRepetitionPenaltyLogitsProcessor", + "EosTokenCriteria", "EpsilonLogitsWarper", "EtaLogitsWarper", "ExponentialDecayLengthPenalty", @@ -1444,6 +1445,7 @@ else: "SequenceBiasLogitsProcessor", "StoppingCriteria", "StoppingCriteriaList", + "StopStringCriteria", "SuppressTokensAtBeginLogitsProcessor", "SuppressTokensLogitsProcessor", "TemperatureLogitsWarper", @@ -6372,6 +6374,7 @@ if TYPE_CHECKING: DisjunctiveConstraint, EncoderNoRepeatNGramLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor, + EosTokenCriteria, EpsilonLogitsWarper, EtaLogitsWarper, ExponentialDecayLengthPenalty, @@ -6397,6 +6400,7 @@ if TYPE_CHECKING: SequenceBiasLogitsProcessor, StoppingCriteria, StoppingCriteriaList, + StopStringCriteria, SuppressTokensAtBeginLogitsProcessor, SuppressTokensLogitsProcessor, TemperatureLogitsWarper, diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 44c040ca6a..48392400c4 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -98,7 +98,7 @@ class MaxNewTokensCriteria(StoppingCriteria): def __init__(self, start_length: int, max_new_tokens: int): warnings.warn( - "The class `MaxNewTokensCriteria` is deprecated. " + "The class `MaxNewTokensCriteria` is deprecated and will be removed in v4.43. " f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` " "with `max_length = start_length + max_new_tokens` instead.", FutureWarning, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index f91bbbe4fc..d5c64cc141 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -177,6 +177,13 @@ class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): requires_backends(self, ["torch"]) +class EosTokenCriteria(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EpsilonLogitsWarper(metaclass=DummyObject): _backends = ["torch"] @@ -352,6 +359,13 @@ class StoppingCriteriaList(metaclass=DummyObject): requires_backends(self, ["torch"]) +class StopStringCriteria(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject): _backends = ["torch"]