From 5c17918fe4cda80dae5b7ec8f0b2d23a813c4a05 Mon Sep 17 00:00:00 2001 From: Robert Dargavel Smith Date: Fri, 3 Jun 2022 10:08:35 +0100 Subject: [PATCH] Allow from transformers import TypicalLogitsWarper (#17477) * Allow from transformers import TypicalLogitsWarper * Added TypicalLogitsWarper * Allow from transformers import TypicalLogitsWarper * Allow from transformers import TypicalLogitsWarper * Allow from transformers import TypicalLogitsWarper * Allow from transformers import TypicalLogitsWarper Added TypicalLogitsWarper Allow from transformers import TypicalLogitsWarper Allow from transformers import TypicalLogitsWarper Allow from transformers import TypicalLogitsWarper --- docs/source/en/internal/generation_utils.mdx | 3 +++ src/transformers/__init__.py | 2 ++ src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 3 files changed, 12 insertions(+) diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index 5a717edb98..bdb6c7c59c 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -127,6 +127,9 @@ generation. [[autodoc]] TopKLogitsWarper - __call__ +[[autodoc]] TypicalLogitsWarper + - __call__ + [[autodoc]] NoRepeatNGramLogitsProcessor - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b40d95d4d9..cfbb4f0dbd 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -703,6 +703,7 @@ else: "TemperatureLogitsWarper", "TopKLogitsWarper", "TopPLogitsWarper", + "TypicalLogitsWarper", ] _import_structure["generation_stopping_criteria"] = [ "MaxLengthCriteria", @@ -3218,6 +3219,7 @@ if TYPE_CHECKING: TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, + TypicalLogitsWarper, ) from .generation_stopping_criteria import ( MaxLengthCriteria, diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 4b7e06e635..5459ecfbb2 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -234,6 +234,13 @@ class TopPLogitsWarper(metaclass=DummyObject): requires_backends(self, ["torch"]) +class TypicalLogitsWarper(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MaxLengthCriteria(metaclass=DummyObject): _backends = ["torch"]