Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ccb92be23d | ||
|
|
657eb26c49 | ||
|
|
13aef138ad | ||
|
|
6836e9dd43 | ||
|
|
e82040e12b | ||
|
|
9d42e402ef | ||
|
|
27f91578f1 | ||
|
|
6a029a8b9e |
@@ -641,6 +641,8 @@
|
||||
title: GIT
|
||||
- local: model_doc/groupvit
|
||||
title: GroupViT
|
||||
- local: model_doc/idefics
|
||||
title: IDEFICS
|
||||
- local: model_doc/instructblip
|
||||
title: InstructBLIP
|
||||
- local: model_doc/layoutlm
|
||||
@@ -692,8 +694,6 @@
|
||||
sections:
|
||||
- local: model_doc/decision_transformer
|
||||
title: Decision Transformer
|
||||
- local: model_doc/idefics
|
||||
title: IDEFICS
|
||||
- local: model_doc/trajectory_transformer
|
||||
title: Trajectory Transformer
|
||||
title: Reinforcement learning models
|
||||
|
||||
@@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`.
|
||||
We document here all output types.
|
||||
|
||||
|
||||
### GreedySearchOutput
|
||||
|
||||
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
|
||||
### PyTorch
|
||||
|
||||
[[autodoc]] generation.GreedySearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.FlaxGreedySearchOutput
|
||||
|
||||
### SampleOutput
|
||||
|
||||
[[autodoc]] generation.SampleDecoderOnlyOutput
|
||||
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.SampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.FlaxSampleOutput
|
||||
|
||||
### BeamSearchOutput
|
||||
|
||||
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
|
||||
[[autodoc]] generation.SampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
|
||||
|
||||
### BeamSampleOutput
|
||||
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
|
||||
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
|
||||
|
||||
### TensorFlow
|
||||
|
||||
[[autodoc]] generation.TFGreedySearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFSampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFSampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput
|
||||
|
||||
[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput
|
||||
|
||||
[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput
|
||||
|
||||
### FLAX
|
||||
|
||||
[[autodoc]] generation.FlaxSampleOutput
|
||||
|
||||
[[autodoc]] generation.FlaxGreedySearchOutput
|
||||
|
||||
[[autodoc]] generation.FlaxBeamSearchOutput
|
||||
|
||||
## LogitsProcessor
|
||||
|
||||
A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
|
||||
generation.
|
||||
|
||||
### PyTorch
|
||||
|
||||
[[autodoc]] AlternatingCodebooksLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EpsilonLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] EtaLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ExponentialDecayLengthPenalty
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedBOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForceTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] HammingDiversityLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] InfNanRemoveLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] LogitNormalization
|
||||
- __call__
|
||||
|
||||
[[autodoc]] LogitsProcessor
|
||||
- __call__
|
||||
|
||||
@@ -123,43 +188,54 @@ generation.
|
||||
[[autodoc]] MinNewTokensLengthLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] RepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TypicalLogitsWarper
|
||||
[[autodoc]] NoBadWordsLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] NoRepeatNGramLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] SequenceBiasLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] NoBadWordsLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] PrefixConstrainedLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] HammingDiversityLogitsProcessor
|
||||
[[autodoc]] RepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedBOSTokenLogitsProcessor
|
||||
[[autodoc]] SequenceBiasLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] ForcedEOSTokenLogitsProcessor
|
||||
[[autodoc]] SuppressTokensAtBeginLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] InfNanRemoveLogitsProcessor
|
||||
[[autodoc]] SuppressTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TypicalLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] WhisperTimeStampLogitsProcessor
|
||||
- __call__
|
||||
|
||||
### TensorFlow
|
||||
|
||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForceTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFLogitsProcessor
|
||||
@@ -171,15 +247,6 @@ generation.
|
||||
[[autodoc]] TFLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFMinLengthLogitsProcessor
|
||||
- __call__
|
||||
|
||||
@@ -192,10 +259,30 @@ generation.
|
||||
[[autodoc]] TFRepetitionPenaltyLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForcedBOSTokenLogitsProcessor
|
||||
[[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFForcedEOSTokenLogitsProcessor
|
||||
[[autodoc]] TFSuppressTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] TFTopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
### FLAX
|
||||
|
||||
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForceTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxLogitsProcessor
|
||||
@@ -207,27 +294,30 @@ generation.
|
||||
[[autodoc]] FlaxLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTemperatureLogitsWarper
|
||||
[[autodoc]] FlaxMinLengthLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTopPLogitsWarper
|
||||
[[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxSuppressTokensLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTemperatureLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxTopKLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
|
||||
[[autodoc]] FlaxTopPLogitsWarper
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
|
||||
- __call__
|
||||
|
||||
[[autodoc]] FlaxMinLengthLogitsProcessor
|
||||
[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
|
||||
- __call__
|
||||
|
||||
## StoppingCriteria
|
||||
|
||||
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token).
|
||||
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusively available to our PyTorch implementations.
|
||||
|
||||
[[autodoc]] StoppingCriteria
|
||||
- __call__
|
||||
@@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
|
||||
|
||||
## Constraints
|
||||
|
||||
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output.
|
||||
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusively available to our PyTorch implementations.
|
||||
|
||||
[[autodoc]] Constraint
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -425,7 +425,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.32.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.32.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.32.0"
|
||||
__version__ = "4.32.1"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -995,17 +995,26 @@ else:
|
||||
_import_structure["deepspeed"] = []
|
||||
_import_structure["generation"].extend(
|
||||
[
|
||||
"AlternatingCodebooksLogitsProcessor",
|
||||
"BeamScorer",
|
||||
"BeamSearchScorer",
|
||||
"ClassifierFreeGuidanceLogitsProcessor",
|
||||
"ConstrainedBeamSearchScorer",
|
||||
"Constraint",
|
||||
"ConstraintListState",
|
||||
"DisjunctiveConstraint",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||
"EpsilonLogitsWarper",
|
||||
"EtaLogitsWarper",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"ForceTokensLogitsProcessor",
|
||||
"GenerationMixin",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitNormalization",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
@@ -1021,10 +1030,14 @@ else:
|
||||
"SequenceBiasLogitsProcessor",
|
||||
"StoppingCriteria",
|
||||
"StoppingCriteriaList",
|
||||
"SuppressTokensAtBeginLogitsProcessor",
|
||||
"SuppressTokensLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
"top_k_top_p_filtering",
|
||||
]
|
||||
)
|
||||
@@ -3098,6 +3111,7 @@ else:
|
||||
[
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFGenerationMixin",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
@@ -3106,6 +3120,8 @@ else:
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
@@ -3796,14 +3812,18 @@ else:
|
||||
[
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxForceTokensLogitsProcessor",
|
||||
"FlaxGenerationMixin",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
||||
"FlaxSuppressTokensLogitsProcessor",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
"FlaxWhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
)
|
||||
_import_structure["generation_flax_utils"] = []
|
||||
@@ -4938,17 +4958,26 @@ if TYPE_CHECKING:
|
||||
TextDatasetForNextSentencePrediction,
|
||||
)
|
||||
from .generation import (
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
BeamScorer,
|
||||
BeamSearchScorer,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
ConstrainedBeamSearchScorer,
|
||||
Constraint,
|
||||
ConstraintListState,
|
||||
DisjunctiveConstraint,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
EtaLogitsWarper,
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
ForceTokensLogitsProcessor,
|
||||
GenerationMixin,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
LogitsWarper,
|
||||
@@ -4964,10 +4993,14 @@ if TYPE_CHECKING:
|
||||
SequenceBiasLogitsProcessor,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
from .modeling_utils import PreTrainedModel
|
||||
@@ -6662,6 +6695,7 @@ if TYPE_CHECKING:
|
||||
from .generation import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFGenerationMixin,
|
||||
TFLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
@@ -6670,6 +6704,8 @@ if TYPE_CHECKING:
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
@@ -7221,14 +7257,18 @@ if TYPE_CHECKING:
|
||||
from .generation import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxForceTokensLogitsProcessor,
|
||||
FlaxGenerationMixin,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||
FlaxSuppressTokensLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
FlaxWhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
|
||||
|
||||
@@ -41,12 +41,19 @@ else:
|
||||
"ConstrainedBeamSearchScorer",
|
||||
]
|
||||
_import_structure["logits_process"] = [
|
||||
"AlternatingCodebooksLogitsProcessor",
|
||||
"ClassifierFreeGuidanceLogitsProcessor",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||
"EpsilonLogitsWarper",
|
||||
"EtaLogitsWarper",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"ForcedBOSTokenLogitsProcessor",
|
||||
"ForcedEOSTokenLogitsProcessor",
|
||||
"ForceTokensLogitsProcessor",
|
||||
"HammingDiversityLogitsProcessor",
|
||||
"InfNanRemoveLogitsProcessor",
|
||||
"LogitNormalization",
|
||||
"LogitsProcessor",
|
||||
"LogitsProcessorList",
|
||||
"LogitsWarper",
|
||||
@@ -57,15 +64,14 @@ else:
|
||||
"PrefixConstrainedLogitsProcessor",
|
||||
"RepetitionPenaltyLogitsProcessor",
|
||||
"SequenceBiasLogitsProcessor",
|
||||
"EncoderRepetitionPenaltyLogitsProcessor",
|
||||
"SuppressTokensLogitsProcessor",
|
||||
"SuppressTokensAtBeginLogitsProcessor",
|
||||
"TemperatureLogitsWarper",
|
||||
"TopKLogitsWarper",
|
||||
"TopPLogitsWarper",
|
||||
"TypicalLogitsWarper",
|
||||
"EncoderNoRepeatNGramLogitsProcessor",
|
||||
"ExponentialDecayLengthPenalty",
|
||||
"LogitNormalization",
|
||||
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
|
||||
"WhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
_import_structure["stopping_criteria"] = [
|
||||
"MaxNewTokensCriteria",
|
||||
@@ -99,6 +105,7 @@ else:
|
||||
_import_structure["tf_logits_process"] = [
|
||||
"TFForcedBOSTokenLogitsProcessor",
|
||||
"TFForcedEOSTokenLogitsProcessor",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFLogitsProcessor",
|
||||
"TFLogitsProcessorList",
|
||||
"TFLogitsWarper",
|
||||
@@ -106,12 +113,11 @@ else:
|
||||
"TFNoBadWordsLogitsProcessor",
|
||||
"TFNoRepeatNGramLogitsProcessor",
|
||||
"TFRepetitionPenaltyLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
"TFTemperatureLogitsWarper",
|
||||
"TFTopKLogitsWarper",
|
||||
"TFTopPLogitsWarper",
|
||||
"TFForceTokensLogitsProcessor",
|
||||
"TFSuppressTokensAtBeginLogitsProcessor",
|
||||
"TFSuppressTokensLogitsProcessor",
|
||||
]
|
||||
_import_structure["tf_utils"] = [
|
||||
"TFGenerationMixin",
|
||||
@@ -137,13 +143,17 @@ else:
|
||||
_import_structure["flax_logits_process"] = [
|
||||
"FlaxForcedBOSTokenLogitsProcessor",
|
||||
"FlaxForcedEOSTokenLogitsProcessor",
|
||||
"FlaxForceTokensLogitsProcessor",
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxMinLengthLogitsProcessor",
|
||||
"FlaxSuppressTokensAtBeginLogitsProcessor",
|
||||
"FlaxSuppressTokensLogitsProcessor",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
"FlaxWhisperTimeStampLogitsProcessor",
|
||||
]
|
||||
_import_structure["flax_utils"] = [
|
||||
"FlaxGenerationMixin",
|
||||
@@ -165,6 +175,8 @@ if TYPE_CHECKING:
|
||||
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
|
||||
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
|
||||
from .logits_process import (
|
||||
AlternatingCodebooksLogitsProcessor,
|
||||
ClassifierFreeGuidanceLogitsProcessor,
|
||||
EncoderNoRepeatNGramLogitsProcessor,
|
||||
EncoderRepetitionPenaltyLogitsProcessor,
|
||||
EpsilonLogitsWarper,
|
||||
@@ -172,6 +184,7 @@ if TYPE_CHECKING:
|
||||
ExponentialDecayLengthPenalty,
|
||||
ForcedBOSTokenLogitsProcessor,
|
||||
ForcedEOSTokenLogitsProcessor,
|
||||
ForceTokensLogitsProcessor,
|
||||
HammingDiversityLogitsProcessor,
|
||||
InfNanRemoveLogitsProcessor,
|
||||
LogitNormalization,
|
||||
@@ -185,11 +198,14 @@ if TYPE_CHECKING:
|
||||
PrefixConstrainedLogitsProcessor,
|
||||
RepetitionPenaltyLogitsProcessor,
|
||||
SequenceBiasLogitsProcessor,
|
||||
SuppressTokensAtBeginLogitsProcessor,
|
||||
SuppressTokensLogitsProcessor,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
TypicalLogitsWarper,
|
||||
UnbatchedClassifierFreeGuidanceLogitsProcessor,
|
||||
WhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .stopping_criteria import (
|
||||
MaxLengthCriteria,
|
||||
@@ -261,13 +277,17 @@ if TYPE_CHECKING:
|
||||
from .flax_logits_process import (
|
||||
FlaxForcedBOSTokenLogitsProcessor,
|
||||
FlaxForcedEOSTokenLogitsProcessor,
|
||||
FlaxForceTokensLogitsProcessor,
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxMinLengthLogitsProcessor,
|
||||
FlaxSuppressTokensAtBeginLogitsProcessor,
|
||||
FlaxSuppressTokensLogitsProcessor,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
FlaxWhisperTimeStampLogitsProcessor,
|
||||
)
|
||||
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
|
||||
else:
|
||||
|
||||
@@ -135,7 +135,7 @@ class BloomTokenizerFast(PreTrainedTokenizerFast):
|
||||
|
||||
if add_prefix_space:
|
||||
pre_tok_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
|
||||
decoder_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
|
||||
decoder_state = decoder_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
|
||||
self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tok_state)
|
||||
self.backend_tokenizer.decoder = pickle.loads(decoder_state)
|
||||
|
||||
|
||||
@@ -1314,9 +1314,7 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
|
||||
]
|
||||
if annotations is not None:
|
||||
annotations = [
|
||||
self.normalize_annotation(
|
||||
annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
|
||||
)
|
||||
self.normalize_annotation(annotation, get_image_size(image, input_data_format))
|
||||
for annotation, image in zip(annotations, images)
|
||||
]
|
||||
|
||||
|
||||
@@ -1312,9 +1312,7 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
|
||||
]
|
||||
if annotations is not None:
|
||||
annotations = [
|
||||
self.normalize_annotation(
|
||||
annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
|
||||
)
|
||||
self.normalize_annotation(annotation, get_image_size(image, input_data_format))
|
||||
for annotation, image in zip(annotations, images)
|
||||
]
|
||||
|
||||
|
||||
@@ -1284,9 +1284,7 @@ class DetrImageProcessor(BaseImageProcessor):
|
||||
]
|
||||
if annotations is not None:
|
||||
annotations = [
|
||||
self.normalize_annotation(
|
||||
annotation, get_image_size(image, input_data_format), input_data_format=input_data_format
|
||||
)
|
||||
self.normalize_annotation(annotation, get_image_size(image, input_data_format))
|
||||
for annotation, image in zip(annotations, images)
|
||||
]
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ else:
|
||||
_import_structure["modeling_idefics"] = [
|
||||
"IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"IdeficsForVisionText2Text",
|
||||
"IdeficsGatedCrossAttentionLayer",
|
||||
"IdeficsModel",
|
||||
"IdeficsPreTrainedModel",
|
||||
]
|
||||
@@ -62,7 +61,6 @@ if TYPE_CHECKING:
|
||||
from .modeling_idefics import (
|
||||
IDEFICS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
IdeficsForVisionText2Text,
|
||||
IdeficsGatedCrossAttentionLayer,
|
||||
IdeficsModel,
|
||||
IdeficsPreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -1482,9 +1482,9 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
>>> from transformers import AutoTokenizer, IdeficsForVisionText2Text
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> model = IdeficsForVisionText2Text.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
|
||||
@@ -220,13 +220,14 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
||||
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
||||
"""
|
||||
if self.legacy:
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
|
||||
text = self.unk_token + text
|
||||
tokens = self.sp_model.encode(text, out_type=str)
|
||||
return tokens[unk_token_length:]
|
||||
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
||||
return tokens
|
||||
|
||||
# 1. Encode string + prefix ex: "<unk> Hey"
|
||||
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
||||
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
||||
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
|
||||
@@ -363,6 +363,10 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
tokens = tokens[1:]
|
||||
return tokens
|
||||
|
||||
@property
|
||||
def unk_token_length(self):
|
||||
return len(self.sp_model.encode(str(self.unk_token)))
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
"""
|
||||
Returns a tokenized string.
|
||||
@@ -373,13 +377,14 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
`unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
|
||||
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
|
||||
"""
|
||||
if self.legacy:
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
unk_token_length = len(self.sp_model.encode(str(self.unk_token)))
|
||||
text = self.unk_token + text
|
||||
tokens = self.sp_model.encode(text, out_type=str)
|
||||
return tokens[unk_token_length:]
|
||||
if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
|
||||
return tokens
|
||||
|
||||
# 1. Encode string + prefix ex: "<unk> Hey"
|
||||
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
|
||||
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
|
||||
return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
|
||||
@@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxGenerationMixin(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
@@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["flax"]
|
||||
|
||||
|
||||
@@ -79,6 +79,13 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class BeamScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -93,6 +100,13 @@ class BeamSearchScorer(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ConstrainedBeamSearchScorer(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -121,6 +135,41 @@ class DisjunctiveConstraint(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EpsilonLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class EtaLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ExponentialDecayLengthPenalty(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -135,6 +184,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class ForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class GenerationMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -156,6 +212,13 @@ class InfNanRemoveLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LogitNormalization(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class LogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -261,6 +324,20 @@ class StoppingCriteriaList(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class SuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class TemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -289,6 +366,20 @@ class TypicalLogitsWarper(metaclass=DummyObject):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
def top_k_top_p_filtering(*args, **kwargs):
|
||||
requires_backends(top_k_top_p_filtering, ["torch"])
|
||||
|
||||
|
||||
@@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFForceTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFGenerationMixin(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
@@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSuppressTokensLogitsProcessor(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFTemperatureLogitsWarper(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
|
||||
@@ -133,3 +133,10 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# maximum sequence length of the positional embeddings.
|
||||
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
|
||||
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
|
||||
|
||||
def test_add_prefix_space_fast(self):
|
||||
tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
|
||||
tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
|
||||
tokens_w_prefix = tokenizer_w_prefix.tokenize("Hey")
|
||||
tokens_wo_prefix = tokenizer_wo_prefix.tokenize("Hey")
|
||||
self.assertNotEqual(tokens_w_prefix, tokens_wo_prefix)
|
||||
|
||||
@@ -546,6 +546,15 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
decoded_tokens = tokenizer.decode(input_ids)
|
||||
self.assertEqual(decoded_tokens, " <s> Hello<s> how")
|
||||
|
||||
def test_some_edge_cases(self):
|
||||
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
|
||||
|
||||
sp_tokens = tokenizer.sp_model.encode("<s>>", out_type=str)
|
||||
self.assertEqual(sp_tokens, ["<", "s", ">>"])
|
||||
tokens = tokenizer.tokenize("<s>>")
|
||||
self.assertNotEqual(sp_tokens, tokens)
|
||||
self.assertEqual(tokens, ["<s>", ">"])
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
|
||||
@@ -128,11 +128,13 @@ class GetFromCacheTests(unittest.TestCase):
|
||||
|
||||
self.assertIsNone(get_file_from_repo(tmp_dir, "b.txt"))
|
||||
|
||||
@unittest.skip("Test is broken, fix me Wauplain!")
|
||||
def test_get_file_gated_repo(self):
|
||||
"""Test download file from a gated repo fails with correct message when not authenticated."""
|
||||
with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."):
|
||||
cached_file(GATED_REPO, README_FILE, use_auth_token=False)
|
||||
|
||||
@unittest.skip("Test is broken, fix me Wauplain!")
|
||||
def test_has_file_gated_repo(self):
|
||||
"""Test check file existence from a gated repo fails with correct message when not authenticated."""
|
||||
with self.assertRaisesRegex(EnvironmentError, "is a gated repository"):
|
||||
|
||||
Reference in New Issue
Block a user