Allow from transformers import TypicalLogitsWarper (#17477)

* Allow from transformers import TypicalLogitsWarper

* Added TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

Added TypicalLogitsWarper

Allow from transformers import TypicalLogitsWarper

Allow from transformers import TypicalLogitsWarper

Allow from transformers import TypicalLogitsWarper
This commit is contained in:
Robert Dargavel Smith
2022-06-03 10:08:35 +01:00
committed by GitHub
parent 607acd4fbd
commit 5c17918fe4
3 changed files with 12 additions and 0 deletions

View File

@@ -127,6 +127,9 @@ generation.
[[autodoc]] TopKLogitsWarper [[autodoc]] TopKLogitsWarper
- __call__ - __call__
[[autodoc]] TypicalLogitsWarper
- __call__
[[autodoc]] NoRepeatNGramLogitsProcessor [[autodoc]] NoRepeatNGramLogitsProcessor
- __call__ - __call__

View File

@@ -703,6 +703,7 @@ else:
"TemperatureLogitsWarper", "TemperatureLogitsWarper",
"TopKLogitsWarper", "TopKLogitsWarper",
"TopPLogitsWarper", "TopPLogitsWarper",
"TypicalLogitsWarper",
] ]
_import_structure["generation_stopping_criteria"] = [ _import_structure["generation_stopping_criteria"] = [
"MaxLengthCriteria", "MaxLengthCriteria",
@@ -3218,6 +3219,7 @@ if TYPE_CHECKING:
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
TypicalLogitsWarper,
) )
from .generation_stopping_criteria import ( from .generation_stopping_criteria import (
MaxLengthCriteria, MaxLengthCriteria,

View File

@@ -234,6 +234,13 @@ class TopPLogitsWarper(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class TypicalLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MaxLengthCriteria(metaclass=DummyObject): class MaxLengthCriteria(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]