Generate: add missing logits processors docs (#25653)
This commit is contained in:
@@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`.
|
|||||||
We document here all output types.
|
We document here all output types.
|
||||||
|
|
||||||
|
|
||||||
### GreedySearchOutput
|
### PyTorch
|
||||||
|
|
||||||
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
|
|
||||||
|
|
||||||
[[autodoc]] generation.GreedySearchEncoderDecoderOutput
|
[[autodoc]] generation.GreedySearchEncoderDecoderOutput
|
||||||
|
|
||||||
[[autodoc]] generation.FlaxGreedySearchOutput
|
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
|
||||||
|
|
||||||
### SampleOutput
|
|
||||||
|
|
||||||
[[autodoc]] generation.SampleDecoderOnlyOutput
|
|
||||||
|
|
||||||
[[autodoc]] generation.SampleEncoderDecoderOutput
|
[[autodoc]] generation.SampleEncoderDecoderOutput
|
||||||
|
|
||||||
[[autodoc]] generation.FlaxSampleOutput
|
[[autodoc]] generation.SampleDecoderOnlyOutput
|
||||||
|
|
||||||
### BeamSearchOutput
|
|
||||||
|
|
||||||
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
|
|
||||||
|
|
||||||
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
|
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
|
||||||
|
|
||||||
### BeamSampleOutput
|
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
|
||||||
|
|
||||||
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
|
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
|
||||||
|
|
||||||
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
|
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
|
||||||
|
|
||||||
|
### TensorFlow
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFSampleEncoderDecoderOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFSampleDecoderOnlyOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput
|
||||||
|
|
||||||
|
### FLAX
|
||||||
|
|
||||||
|
[[autodoc]] generation.FlaxSampleOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.FlaxGreedySearchOutput
|
||||||
|
|
||||||
|
[[autodoc]] generation.FlaxBeamSearchOutput
|
||||||
|
|
||||||
## LogitsProcessor
|
## LogitsProcessor
|
||||||
|
|
||||||
A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
|
A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
|
||||||
generation.
|
generation.
|
||||||
|
|
||||||
|
### PyTorch
|
||||||
|
|
||||||
|
[[autodoc]] AlternatingCodebooksLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] EpsilonLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] EtaLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] ExponentialDecayLengthPenalty
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] ForcedBOSTokenLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] ForcedEOSTokenLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] ForceTokensLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] HammingDiversityLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] InfNanRemoveLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] LogitNormalization
|
||||||
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] LogitsProcessor
|
[[autodoc]] LogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
@@ -123,43 +188,54 @@ generation.
|
|||||||
[[autodoc]] MinNewTokensLengthLogitsProcessor
|
[[autodoc]] MinNewTokensLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] TemperatureLogitsWarper
|
[[autodoc]] NoBadWordsLogitsProcessor
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] RepetitionPenaltyLogitsProcessor
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] TopPLogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] TopKLogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] TypicalLogitsWarper
|
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] NoRepeatNGramLogitsProcessor
|
[[autodoc]] NoRepeatNGramLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] SequenceBiasLogitsProcessor
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] NoBadWordsLogitsProcessor
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] PrefixConstrainedLogitsProcessor
|
[[autodoc]] PrefixConstrainedLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] HammingDiversityLogitsProcessor
|
[[autodoc]] RepetitionPenaltyLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] ForcedBOSTokenLogitsProcessor
|
[[autodoc]] SequenceBiasLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] ForcedEOSTokenLogitsProcessor
|
[[autodoc]] SuppressTokensAtBeginLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] InfNanRemoveLogitsProcessor
|
[[autodoc]] SuppressTokensLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TemperatureLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TopKLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TopPLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TypicalLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] WhisperTimeStampLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
### TensorFlow
|
||||||
|
|
||||||
|
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TFForcedEOSTokenLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TFForceTokensLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] TFLogitsProcessor
|
[[autodoc]] TFLogitsProcessor
|
||||||
@@ -171,15 +247,6 @@ generation.
|
|||||||
[[autodoc]] TFLogitsWarper
|
[[autodoc]] TFLogitsWarper
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] TFTemperatureLogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] TFTopPLogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] TFTopKLogitsWarper
|
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] TFMinLengthLogitsProcessor
|
[[autodoc]] TFMinLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
@@ -192,10 +259,30 @@ generation.
|
|||||||
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
|
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] TFForcedEOSTokenLogitsProcessor
|
[[autodoc]] TFSuppressTokensLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TFTemperatureLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TFTopKLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] TFTopPLogitsWarper
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
### FLAX
|
||||||
|
|
||||||
|
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] FlaxForceTokensLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] FlaxLogitsProcessor
|
[[autodoc]] FlaxLogitsProcessor
|
||||||
@@ -207,27 +294,30 @@ generation.
|
|||||||
[[autodoc]] FlaxLogitsWarper
|
[[autodoc]] FlaxLogitsWarper
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] FlaxTemperatureLogitsWarper
|
[[autodoc]] FlaxMinLengthLogitsProcessor
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] FlaxTopPLogitsWarper
|
[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] FlaxSuppressTokensLogitsProcessor
|
||||||
|
- __call__
|
||||||
|
|
||||||
|
[[autodoc]] FlaxTemperatureLogitsWarper
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] FlaxTopKLogitsWarper
|
[[autodoc]] FlaxTopKLogitsWarper
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
|
[[autodoc]] FlaxTopPLogitsWarper
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
|
[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
|
||||||
- __call__
|
|
||||||
|
|
||||||
[[autodoc]] FlaxMinLengthLogitsProcessor
|
|
||||||
- __call__
|
- __call__
|
||||||
|
|
||||||
## StoppingCriteria
|
## StoppingCriteria
|
||||||
|
|
||||||
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token).
|
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusivelly available to our PyTorch implementations.
|
||||||
|
|
||||||
[[autodoc]] StoppingCriteria
|
[[autodoc]] StoppingCriteria
|
||||||
- __call__
|
- __call__
|
||||||
@@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
|
|||||||
|
|
||||||
## Constraints
|
## Constraints
|
||||||
|
|
||||||
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
|
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusivelly available to our PyTorch implementations.
|
||||||
|
|
||||||
[[autodoc]] Constraint
|
[[autodoc]] Constraint
|
||||||
|
|
||||||
|
|||||||
@@ -995,17 +995,26 @@ else:
|
|||||||
_import_structure["deepspeed"] = []
|
_import_structure["deepspeed"] = []
|
||||||
_import_structure["generation"].extend(
|
_import_structure["generation"].extend(
|
||||||
[
|
[
|
||||||
|
"AlternatingCodebooksLogitsProcessor",
|
||||||
"BeamScorer",
|
"BeamScorer",
|
||||||
"BeamSearchScorer",
|
"BeamSearchScorer",
|
||||||
|
"ClassifierFreeGuidanceLogitsProcessor",
|
||||||
"ConstrainedBeamSearchScorer",
|
"ConstrainedBeamSearchScorer",
|
||||||
"Constraint",
|
"Constraint",
|
||||||
"ConstraintListState",
|
"ConstraintListState",
|
||||||
"DisjunctiveConstraint",
|
"DisjunctiveConstraint",
|
||||||
|
"EncoderNoRepeatNGramLogitsProcessor",
|
||||||
|
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||||
|
"EpsilonLogitsWarper",
|
||||||
|
"EtaLogitsWarper",
|
||||||
|
"ExponentialDecayLengthPenalty",
|
||||||
"ForcedBOSTokenLogitsProcessor",
|
"ForcedBOSTokenLogitsProcessor",
|
||||||
"ForcedEOSTokenLogitsProcessor",
|
"ForcedEOSTokenLogitsProcessor",
|
||||||
|
"ForceTokensLogitsProcessor",
|
||||||
"GenerationMixin",
|
"GenerationMixin",
|
||||||
"HammingDiversityLogitsProcessor",
|
"HammingDiversityLogitsProcessor",
|
||||||
"InfNanRemoveLogitsProcessor",
|
"InfNanRemoveLogitsProcessor",
|
||||||
|
"LogitNormalization",
|
||||||
"LogitsProcessor",
|
"LogitsProcessor",
|
||||||
"LogitsProcessorList",
|
"LogitsProcessorList",
|
||||||
"LogitsWarper",
|
"LogitsWarper",
|
||||||
@@ -1021,10 +1030,14 @@ else:
|
|||||||
"SequenceBiasLogitsProcessor",
|
"SequenceBiasLogitsProcessor",
|
||||||
"StoppingCriteria",
|
"StoppingCriteria",
|
||||||
"StoppingCriteriaList",
|
"StoppingCriteriaList",
|
||||||
|
"SuppressTokensAtBeginLogitsProcessor",
|
||||||
|
"SuppressTokensLogitsProcessor",
|
||||||
"TemperatureLogitsWarper",
|
"TemperatureLogitsWarper",
|
||||||
"TopKLogitsWarper",
|
"TopKLogitsWarper",
|
||||||
"TopPLogitsWarper",
|
"TopPLogitsWarper",
|
||||||
"TypicalLogitsWarper",
|
"TypicalLogitsWarper",
|
||||||
|
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||||
|
"WhisperTimeStampLogitsProcessor",
|
||||||
"top_k_top_p_filtering",
|
"top_k_top_p_filtering",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -3098,6 +3111,7 @@ else:
|
|||||||
[
|
[
|
||||||
"TFForcedBOSTokenLogitsProcessor",
|
"TFForcedBOSTokenLogitsProcessor",
|
||||||
"TFForcedEOSTokenLogitsProcessor",
|
"TFForcedEOSTokenLogitsProcessor",
|
||||||
|
"TFForceTokensLogitsProcessor",
|
||||||
"TFGenerationMixin",
|
"TFGenerationMixin",
|
||||||
"TFLogitsProcessor",
|
"TFLogitsProcessor",
|
||||||
"TFLogitsProcessorList",
|
"TFLogitsProcessorList",
|
||||||
@@ -3106,6 +3120,8 @@ else:
|
|||||||
"TFNoBadWordsLogitsProcessor",
|
"TFNoBadWordsLogitsProcessor",
|
||||||
"TFNoRepeatNGramLogitsProcessor",
|
"TFNoRepeatNGramLogitsProcessor",
|
||||||
"TFRepetitionPenaltyLogitsProcessor",
|
"TFRepetitionPenaltyLogitsProcessor",
|
||||||
|
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||||
|
"TFSuppressTokensLogitsProcessor",
|
||||||
"TFTemperatureLogitsWarper",
|
"TFTemperatureLogitsWarper",
|
||||||
"TFTopKLogitsWarper",
|
"TFTopKLogitsWarper",
|
||||||
"TFTopPLogitsWarper",
|
"TFTopPLogitsWarper",
|
||||||
@@ -3796,14 +3812,18 @@ else:
|
|||||||
[
|
[
|
||||||
"FlaxForcedBOSTokenLogitsProcessor",
|
"FlaxForcedBOSTokenLogitsProcessor",
|
||||||
"FlaxForcedEOSTokenLogitsProcessor",
|
"FlaxForcedEOSTokenLogitsProcessor",
|
||||||
|
"FlaxForceTokensLogitsProcessor",
|
||||||
"FlaxGenerationMixin",
|
"FlaxGenerationMixin",
|
||||||
"FlaxLogitsProcessor",
|
"FlaxLogitsProcessor",
|
||||||
"FlaxLogitsProcessorList",
|
"FlaxLogitsProcessorList",
|
||||||
"FlaxLogitsWarper",
|
"FlaxLogitsWarper",
|
||||||
"FlaxMinLengthLogitsProcessor",
|
"FlaxMinLengthLogitsProcessor",
|
||||||
"FlaxTemperatureLogitsWarper",
|
"FlaxTemperatureLogitsWarper",
|
||||||
|
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
||||||
|
"FlaxSuppressTokensLogitsProcessor",
|
||||||
"FlaxTopKLogitsWarper",
|
"FlaxTopKLogitsWarper",
|
||||||
"FlaxTopPLogitsWarper",
|
"FlaxTopPLogitsWarper",
|
||||||
|
"FlaxWhisperTimeStampLogitsProcessor",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["generation_flax_utils"] = []
|
_import_structure["generation_flax_utils"] = []
|
||||||
@@ -4938,17 +4958,26 @@ if TYPE_CHECKING:
|
|||||||
TextDatasetForNextSentencePrediction,
|
TextDatasetForNextSentencePrediction,
|
||||||
)
|
)
|
||||||
from .generation import (
|
from .generation import (
|
||||||
|
AlternatingCodebooksLogitsProcessor,
|
||||||
BeamScorer,
|
BeamScorer,
|
||||||
BeamSearchScorer,
|
BeamSearchScorer,
|
||||||
|
ClassifierFreeGuidanceLogitsProcessor,
|
||||||
ConstrainedBeamSearchScorer,
|
ConstrainedBeamSearchScorer,
|
||||||
Constraint,
|
Constraint,
|
||||||
ConstraintListState,
|
ConstraintListState,
|
||||||
DisjunctiveConstraint,
|
DisjunctiveConstraint,
|
||||||
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
|
EpsilonLogitsWarper,
|
||||||
|
EtaLogitsWarper,
|
||||||
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
|
ForceTokensLogitsProcessor,
|
||||||
GenerationMixin,
|
GenerationMixin,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
InfNanRemoveLogitsProcessor,
|
InfNanRemoveLogitsProcessor,
|
||||||
|
LogitNormalization,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
LogitsWarper,
|
LogitsWarper,
|
||||||
@@ -4964,10 +4993,14 @@ if TYPE_CHECKING:
|
|||||||
SequenceBiasLogitsProcessor,
|
SequenceBiasLogitsProcessor,
|
||||||
StoppingCriteria,
|
StoppingCriteria,
|
||||||
StoppingCriteriaList,
|
StoppingCriteriaList,
|
||||||
|
SuppressTokensAtBeginLogitsProcessor,
|
||||||
|
SuppressTokensLogitsProcessor,
|
||||||
TemperatureLogitsWarper,
|
TemperatureLogitsWarper,
|
||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
|
WhisperTimeStampLogitsProcessor,
|
||||||
top_k_top_p_filtering,
|
top_k_top_p_filtering,
|
||||||
)
|
)
|
||||||
from .modeling_utils import PreTrainedModel
|
from .modeling_utils import PreTrainedModel
|
||||||
@@ -6662,6 +6695,7 @@ if TYPE_CHECKING:
|
|||||||
from .generation import (
|
from .generation import (
|
||||||
TFForcedBOSTokenLogitsProcessor,
|
TFForcedBOSTokenLogitsProcessor,
|
||||||
TFForcedEOSTokenLogitsProcessor,
|
TFForcedEOSTokenLogitsProcessor,
|
||||||
|
TFForceTokensLogitsProcessor,
|
||||||
TFGenerationMixin,
|
TFGenerationMixin,
|
||||||
TFLogitsProcessor,
|
TFLogitsProcessor,
|
||||||
TFLogitsProcessorList,
|
TFLogitsProcessorList,
|
||||||
@@ -6670,6 +6704,8 @@ if TYPE_CHECKING:
|
|||||||
TFNoBadWordsLogitsProcessor,
|
TFNoBadWordsLogitsProcessor,
|
||||||
TFNoRepeatNGramLogitsProcessor,
|
TFNoRepeatNGramLogitsProcessor,
|
||||||
TFRepetitionPenaltyLogitsProcessor,
|
TFRepetitionPenaltyLogitsProcessor,
|
||||||
|
TFSuppressTokensAtBeginLogitsProcessor,
|
||||||
|
TFSuppressTokensLogitsProcessor,
|
||||||
TFTemperatureLogitsWarper,
|
TFTemperatureLogitsWarper,
|
||||||
TFTopKLogitsWarper,
|
TFTopKLogitsWarper,
|
||||||
TFTopPLogitsWarper,
|
TFTopPLogitsWarper,
|
||||||
@@ -7221,14 +7257,18 @@ if TYPE_CHECKING:
|
|||||||
from .generation import (
|
from .generation import (
|
||||||
FlaxForcedBOSTokenLogitsProcessor,
|
FlaxForcedBOSTokenLogitsProcessor,
|
||||||
FlaxForcedEOSTokenLogitsProcessor,
|
FlaxForcedEOSTokenLogitsProcessor,
|
||||||
|
FlaxForceTokensLogitsProcessor,
|
||||||
FlaxGenerationMixin,
|
FlaxGenerationMixin,
|
||||||
FlaxLogitsProcessor,
|
FlaxLogitsProcessor,
|
||||||
FlaxLogitsProcessorList,
|
FlaxLogitsProcessorList,
|
||||||
FlaxLogitsWarper,
|
FlaxLogitsWarper,
|
||||||
FlaxMinLengthLogitsProcessor,
|
FlaxMinLengthLogitsProcessor,
|
||||||
|
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||||
|
FlaxSuppressTokensLogitsProcessor,
|
||||||
FlaxTemperatureLogitsWarper,
|
FlaxTemperatureLogitsWarper,
|
||||||
FlaxTopKLogitsWarper,
|
FlaxTopKLogitsWarper,
|
||||||
FlaxTopPLogitsWarper,
|
FlaxTopPLogitsWarper,
|
||||||
|
FlaxWhisperTimeStampLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||||
|
|
||||||
|
|||||||
@@ -41,12 +41,19 @@ else:
|
|||||||
"ConstrainedBeamSearchScorer",
|
"ConstrainedBeamSearchScorer",
|
||||||
]
|
]
|
||||||
_import_structure["logits_process"] = [
|
_import_structure["logits_process"] = [
|
||||||
|
"AlternatingCodebooksLogitsProcessor",
|
||||||
|
"ClassifierFreeGuidanceLogitsProcessor",
|
||||||
|
"EncoderNoRepeatNGramLogitsProcessor",
|
||||||
|
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||||
"EpsilonLogitsWarper",
|
"EpsilonLogitsWarper",
|
||||||
"EtaLogitsWarper",
|
"EtaLogitsWarper",
|
||||||
|
"ExponentialDecayLengthPenalty",
|
||||||
"ForcedBOSTokenLogitsProcessor",
|
"ForcedBOSTokenLogitsProcessor",
|
||||||
"ForcedEOSTokenLogitsProcessor",
|
"ForcedEOSTokenLogitsProcessor",
|
||||||
|
"ForceTokensLogitsProcessor",
|
||||||
"HammingDiversityLogitsProcessor",
|
"HammingDiversityLogitsProcessor",
|
||||||
"InfNanRemoveLogitsProcessor",
|
"InfNanRemoveLogitsProcessor",
|
||||||
|
"LogitNormalization",
|
||||||
"LogitsProcessor",
|
"LogitsProcessor",
|
||||||
"LogitsProcessorList",
|
"LogitsProcessorList",
|
||||||
"LogitsWarper",
|
"LogitsWarper",
|
||||||
@@ -57,15 +64,14 @@ else:
|
|||||||
"PrefixConstrainedLogitsProcessor",
|
"PrefixConstrainedLogitsProcessor",
|
||||||
"RepetitionPenaltyLogitsProcessor",
|
"RepetitionPenaltyLogitsProcessor",
|
||||||
"SequenceBiasLogitsProcessor",
|
"SequenceBiasLogitsProcessor",
|
||||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
"SuppressTokensLogitsProcessor",
|
||||||
|
"SuppressTokensAtBeginLogitsProcessor",
|
||||||
"TemperatureLogitsWarper",
|
"TemperatureLogitsWarper",
|
||||||
"TopKLogitsWarper",
|
"TopKLogitsWarper",
|
||||||
"TopPLogitsWarper",
|
"TopPLogitsWarper",
|
||||||
"TypicalLogitsWarper",
|
"TypicalLogitsWarper",
|
||||||
"EncoderNoRepeatNGramLogitsProcessor",
|
|
||||||
"ExponentialDecayLengthPenalty",
|
|
||||||
"LogitNormalization",
|
|
||||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||||
|
"WhisperTimeStampLogitsProcessor",
|
||||||
]
|
]
|
||||||
_import_structure["stopping_criteria"] = [
|
_import_structure["stopping_criteria"] = [
|
||||||
"MaxNewTokensCriteria",
|
"MaxNewTokensCriteria",
|
||||||
@@ -99,6 +105,7 @@ else:
|
|||||||
_import_structure["tf_logits_process"] = [
|
_import_structure["tf_logits_process"] = [
|
||||||
"TFForcedBOSTokenLogitsProcessor",
|
"TFForcedBOSTokenLogitsProcessor",
|
||||||
"TFForcedEOSTokenLogitsProcessor",
|
"TFForcedEOSTokenLogitsProcessor",
|
||||||
|
"TFForceTokensLogitsProcessor",
|
||||||
"TFLogitsProcessor",
|
"TFLogitsProcessor",
|
||||||
"TFLogitsProcessorList",
|
"TFLogitsProcessorList",
|
||||||
"TFLogitsWarper",
|
"TFLogitsWarper",
|
||||||
@@ -106,12 +113,11 @@ else:
|
|||||||
"TFNoBadWordsLogitsProcessor",
|
"TFNoBadWordsLogitsProcessor",
|
||||||
"TFNoRepeatNGramLogitsProcessor",
|
"TFNoRepeatNGramLogitsProcessor",
|
||||||
"TFRepetitionPenaltyLogitsProcessor",
|
"TFRepetitionPenaltyLogitsProcessor",
|
||||||
|
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||||
|
"TFSuppressTokensLogitsProcessor",
|
||||||
"TFTemperatureLogitsWarper",
|
"TFTemperatureLogitsWarper",
|
||||||
"TFTopKLogitsWarper",
|
"TFTopKLogitsWarper",
|
||||||
"TFTopPLogitsWarper",
|
"TFTopPLogitsWarper",
|
||||||
"TFForceTokensLogitsProcessor",
|
|
||||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
|
||||||
"TFSuppressTokensLogitsProcessor",
|
|
||||||
]
|
]
|
||||||
_import_structure["tf_utils"] = [
|
_import_structure["tf_utils"] = [
|
||||||
"TFGenerationMixin",
|
"TFGenerationMixin",
|
||||||
@@ -137,13 +143,17 @@ else:
|
|||||||
_import_structure["flax_logits_process"] = [
|
_import_structure["flax_logits_process"] = [
|
||||||
"FlaxForcedBOSTokenLogitsProcessor",
|
"FlaxForcedBOSTokenLogitsProcessor",
|
||||||
"FlaxForcedEOSTokenLogitsProcessor",
|
"FlaxForcedEOSTokenLogitsProcessor",
|
||||||
|
"FlaxForceTokensLogitsProcessor",
|
||||||
"FlaxLogitsProcessor",
|
"FlaxLogitsProcessor",
|
||||||
"FlaxLogitsProcessorList",
|
"FlaxLogitsProcessorList",
|
||||||
"FlaxLogitsWarper",
|
"FlaxLogitsWarper",
|
||||||
"FlaxMinLengthLogitsProcessor",
|
"FlaxMinLengthLogitsProcessor",
|
||||||
|
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
||||||
|
"FlaxSuppressTokensLogitsProcessor",
|
||||||
"FlaxTemperatureLogitsWarper",
|
"FlaxTemperatureLogitsWarper",
|
||||||
"FlaxTopKLogitsWarper",
|
"FlaxTopKLogitsWarper",
|
||||||
"FlaxTopPLogitsWarper",
|
"FlaxTopPLogitsWarper",
|
||||||
|
"FlaxWhisperTimeStampLogitsProcessor",
|
||||||
]
|
]
|
||||||
_import_structure["flax_utils"] = [
|
_import_structure["flax_utils"] = [
|
||||||
"FlaxGenerationMixin",
|
"FlaxGenerationMixin",
|
||||||
@@ -165,6 +175,8 @@ if TYPE_CHECKING:
|
|||||||
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
||||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||||
from .logits_process import (
|
from .logits_process import (
|
||||||
|
AlternatingCodebooksLogitsProcessor,
|
||||||
|
ClassifierFreeGuidanceLogitsProcessor,
|
||||||
EncoderNoRepeatNGramLogitsProcessor,
|
EncoderNoRepeatNGramLogitsProcessor,
|
||||||
EncoderRepetitionPenaltyLogitsProcessor,
|
EncoderRepetitionPenaltyLogitsProcessor,
|
||||||
EpsilonLogitsWarper,
|
EpsilonLogitsWarper,
|
||||||
@@ -172,6 +184,7 @@ if TYPE_CHECKING:
|
|||||||
ExponentialDecayLengthPenalty,
|
ExponentialDecayLengthPenalty,
|
||||||
ForcedBOSTokenLogitsProcessor,
|
ForcedBOSTokenLogitsProcessor,
|
||||||
ForcedEOSTokenLogitsProcessor,
|
ForcedEOSTokenLogitsProcessor,
|
||||||
|
ForceTokensLogitsProcessor,
|
||||||
HammingDiversityLogitsProcessor,
|
HammingDiversityLogitsProcessor,
|
||||||
InfNanRemoveLogitsProcessor,
|
InfNanRemoveLogitsProcessor,
|
||||||
LogitNormalization,
|
LogitNormalization,
|
||||||
@@ -185,11 +198,14 @@ if TYPE_CHECKING:
|
|||||||
PrefixConstrainedLogitsProcessor,
|
PrefixConstrainedLogitsProcessor,
|
||||||
RepetitionPenaltyLogitsProcessor,
|
RepetitionPenaltyLogitsProcessor,
|
||||||
SequenceBiasLogitsProcessor,
|
SequenceBiasLogitsProcessor,
|
||||||
|
SuppressTokensAtBeginLogitsProcessor,
|
||||||
|
SuppressTokensLogitsProcessor,
|
||||||
TemperatureLogitsWarper,
|
TemperatureLogitsWarper,
|
||||||
TopKLogitsWarper,
|
TopKLogitsWarper,
|
||||||
TopPLogitsWarper,
|
TopPLogitsWarper,
|
||||||
TypicalLogitsWarper,
|
TypicalLogitsWarper,
|
||||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||||
|
WhisperTimeStampLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .stopping_criteria import (
|
from .stopping_criteria import (
|
||||||
MaxLengthCriteria,
|
MaxLengthCriteria,
|
||||||
@@ -261,13 +277,17 @@ if TYPE_CHECKING:
|
|||||||
from .flax_logits_process import (
|
from .flax_logits_process import (
|
||||||
FlaxForcedBOSTokenLogitsProcessor,
|
FlaxForcedBOSTokenLogitsProcessor,
|
||||||
FlaxForcedEOSTokenLogitsProcessor,
|
FlaxForcedEOSTokenLogitsProcessor,
|
||||||
|
FlaxForceTokensLogitsProcessor,
|
||||||
FlaxLogitsProcessor,
|
FlaxLogitsProcessor,
|
||||||
FlaxLogitsProcessorList,
|
FlaxLogitsProcessorList,
|
||||||
FlaxLogitsWarper,
|
FlaxLogitsWarper,
|
||||||
FlaxMinLengthLogitsProcessor,
|
FlaxMinLengthLogitsProcessor,
|
||||||
|
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||||
|
FlaxSuppressTokensLogitsProcessor,
|
||||||
FlaxTemperatureLogitsWarper,
|
FlaxTemperatureLogitsWarper,
|
||||||
FlaxTopKLogitsWarper,
|
FlaxTopKLogitsWarper,
|
||||||
FlaxTopPLogitsWarper,
|
FlaxTopPLogitsWarper,
|
||||||
|
FlaxWhisperTimeStampLogitsProcessor,
|
||||||
)
|
)
|
||||||
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
|
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxGenerationMixin(metaclass=DummyObject):
|
class FlaxGenerationMixin(metaclass=DummyObject):
|
||||||
_backends = ["flax"]
|
_backends = ["flax"]
|
||||||
|
|
||||||
@@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
|
class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
|
||||||
_backends = ["flax"]
|
_backends = ["flax"]
|
||||||
|
|
||||||
@@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxPreTrainedModel(metaclass=DummyObject):
|
class FlaxPreTrainedModel(metaclass=DummyObject):
|
||||||
_backends = ["flax"]
|
_backends = ["flax"]
|
||||||
|
|
||||||
|
|||||||
@@ -79,6 +79,13 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class BeamScorer(metaclass=DummyObject):
|
class BeamScorer(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -93,6 +100,13 @@ class BeamSearchScorer(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
|
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -121,6 +135,41 @@ class DisjunctiveConstraint(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class EpsilonLogitsWarper(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class EtaLogitsWarper(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ExponentialDecayLengthPenalty(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
|
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -135,6 +184,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class GenerationMixin(metaclass=DummyObject):
|
class GenerationMixin(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -156,6 +212,13 @@ class InfNanRemoveLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class LogitNormalization(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessor(metaclass=DummyObject):
|
class LogitsProcessor(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -261,6 +324,20 @@ class StoppingCriteriaList(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class TemperatureLogitsWarper(metaclass=DummyObject):
|
class TemperatureLogitsWarper(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
@@ -289,6 +366,20 @@ class TypicalLogitsWarper(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
def top_k_top_p_filtering(*args, **kwargs):
|
def top_k_top_p_filtering(*args, **kwargs):
|
||||||
requires_backends(top_k_top_p_filtering, ["torch"])
|
requires_backends(top_k_top_p_filtering, ["torch"])
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFGenerationMixin(metaclass=DummyObject):
|
class TFGenerationMixin(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
@@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFSuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||||
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFTemperatureLogitsWarper(metaclass=DummyObject):
|
class TFTemperatureLogitsWarper(metaclass=DummyObject):
|
||||||
_backends = ["tf"]
|
_backends = ["tf"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user