Flax Generate (#11777)
* fix_torch_device_generate_test * remove @ * add * indexing * correct a couple of tests * fix tests * add logits processor * finish top_k, top_p, temp * add docs * correct flax prng key default * improve generate * add generation docs * add docs * make style * revert model outputs change * make style * correct typo * fix tests * fix slow test * add raise * finish generation Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2df546918e
commit
996a315e76
@@ -78,6 +78,9 @@ GreedySearchOutput
|
|||||||
.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
|
.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.generation_flax_utils.FlaxGreedySearchOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
SampleOutput
|
SampleOutput
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
@@ -88,6 +91,9 @@ SampleOutput
|
|||||||
.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
|
.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.generation_flax_utils.FlaxSampleOutput
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
BeamSearchOutput
|
BeamSearchOutput
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
@@ -160,6 +166,24 @@ generation.
|
|||||||
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
|
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
|
||||||
:members: __call__
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxLogitsProcessorList
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxLogitsWarper
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxTemperatureLogitsWarper
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxTopPLogitsWarper
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxTopKLogitsWarper
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
StoppingCriteria
|
StoppingCriteria
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -26,8 +26,9 @@ are common among all the models to:
|
|||||||
|
|
||||||
The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin`
|
The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin`
|
||||||
(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models) or
|
(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models) or
|
||||||
for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models) and
|
for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models),
|
||||||
:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models)
|
:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models) and
|
||||||
|
:class:`~transformers.generation_flax_utils.FlaxGenerationMixin` (for the Flax/JAX models).
|
||||||
|
|
||||||
|
|
||||||
PreTrainedModel
|
PreTrainedModel
|
||||||
@@ -74,6 +75,9 @@ Generation
|
|||||||
.. autoclass:: transformers.generation_tf_utils.TFGenerationMixin
|
.. autoclass:: transformers.generation_tf_utils.TFGenerationMixin
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
.. autoclass:: transformers.generation_flax_utils.FlaxGenerationMixin
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
Pushing to the Hub
|
Pushing to the Hub
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -1437,6 +1437,14 @@ else:
|
|||||||
|
|
||||||
# FLAX-backed objects
|
# FLAX-backed objects
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
_import_structure["generation_flax_logits_process"] = [
|
||||||
|
"FlaxLogitsProcessor",
|
||||||
|
"FlaxLogitsProcessorList",
|
||||||
|
"FlaxLogitsWarper",
|
||||||
|
"FlaxTemperatureLogitsWarper",
|
||||||
|
"FlaxTopKLogitsWarper",
|
||||||
|
"FlaxTopPLogitsWarper",
|
||||||
|
]
|
||||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||||
_import_structure["models.auto"].extend(
|
_import_structure["models.auto"].extend(
|
||||||
[
|
[
|
||||||
@@ -2693,6 +2701,14 @@ if TYPE_CHECKING:
|
|||||||
from .utils.dummy_tf_objects import *
|
from .utils.dummy_tf_objects import *
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
|
from .generation_flax_logits_process import (
|
||||||
|
FlaxLogitsProcessor,
|
||||||
|
FlaxLogitsProcessorList,
|
||||||
|
FlaxLogitsWarper,
|
||||||
|
FlaxTemperatureLogitsWarper,
|
||||||
|
FlaxTopKLogitsWarper,
|
||||||
|
FlaxTopPLogitsWarper,
|
||||||
|
)
|
||||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||||
from .models.auto import (
|
from .models.auto import (
|
||||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||||
|
|||||||
192
src/transformers/generation_flax_logits_process.py
Normal file
192
src/transformers/generation_flax_logits_process.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from abc import ABC
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.lax as lax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxlib.xla_extension as jax_xla
|
||||||
|
|
||||||
|
from .file_utils import add_start_docstrings
|
||||||
|
from .utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||||
|
Args:
|
||||||
|
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
Indices of input sequence tokens in the vocabulary.
|
||||||
|
|
||||||
|
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||||
|
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||||
|
details.
|
||||||
|
|
||||||
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
|
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
|
||||||
|
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
||||||
|
search or log softmax for each vocabulary token when using beam search
|
||||||
|
kwargs:
|
||||||
|
Additional logits processor specific kwargs.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxLogitsProcessor(ABC):
|
||||||
|
"""Abstract base class for all logit processors that can be applied during generation."""
|
||||||
|
|
||||||
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||||
|
"""Flax method for processing logits."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxLogitsWarper(ABC):
|
||||||
|
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||||
|
|
||||||
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||||
|
"""Flax method for warping logits."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxLogitsProcessorList(list):
|
||||||
|
"""
|
||||||
|
This class can be used to create a list of :class:`~transformers.FlaxLogitsProcessor` or
|
||||||
|
:class:`~transformers.FlaxLogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits
|
||||||
|
from list and adds a specific `__call__` method to apply each :class:`~transformers.FlaxLogitsProcessor` or
|
||||||
|
:class:`~transformers.FlaxLogitsWarper` to the inputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
|
||||||
|
for processor in self:
|
||||||
|
function_args = inspect.signature(processor.__call__).parameters
|
||||||
|
if len(function_args) > 2:
|
||||||
|
assert all(
|
||||||
|
arg in kwargs for arg in list(function_args.keys())[2:]
|
||||||
|
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
|
||||||
|
scores = processor(input_ids, scores, **kwargs)
|
||||||
|
else:
|
||||||
|
scores = processor(input_ids, scores)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
||||||
|
r"""
|
||||||
|
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temperature (:obj:`float`):
|
||||||
|
The value used to module the logits distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, temperature: float):
|
||||||
|
if not isinstance(temperature, float) or not (temperature > 0):
|
||||||
|
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
|
||||||
|
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||||
|
scores = scores / self.temperature
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
||||||
|
"""
|
||||||
|
:class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <=
|
||||||
|
prob_cut_off.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_p (:obj:`float`):
|
||||||
|
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
|
||||||
|
kept for generation.
|
||||||
|
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
|
||||||
|
All filtered values will be set to this float value.
|
||||||
|
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||||
|
Minimum number of tokens that cannot be filtered.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
||||||
|
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
||||||
|
|
||||||
|
self.top_p = top_p
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||||
|
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
||||||
|
|
||||||
|
mask_scores = jnp.full_like(scores, self.filter_value)
|
||||||
|
cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
|
||||||
|
score_mask = cumulative_probs < self.top_p
|
||||||
|
|
||||||
|
# include the token that is higher than top_p as well
|
||||||
|
score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True)
|
||||||
|
|
||||||
|
# min tokens to keep
|
||||||
|
score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True)
|
||||||
|
|
||||||
|
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
||||||
|
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
||||||
|
|
||||||
|
return next_scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
||||||
|
r"""
|
||||||
|
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_k (:obj:`int`):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
|
||||||
|
All filtered values will be set to this float value.
|
||||||
|
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||||
|
Minimum number of tokens that cannot be filtered.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||||
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
|
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
||||||
|
|
||||||
|
self.top_k = top_k
|
||||||
|
self.filter_value = filter_value
|
||||||
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
|
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||||
|
batch_size, vocab_size = scores.shape
|
||||||
|
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
||||||
|
|
||||||
|
topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
|
||||||
|
topk_scores, topk_indices = lax.top_k(scores, topk)
|
||||||
|
shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
|
||||||
|
topk_scores_flat = topk_scores.flatten()
|
||||||
|
topk_indices_flat = topk_indices.flatten() + shift
|
||||||
|
|
||||||
|
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
|
||||||
|
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
||||||
|
return next_scores
|
||||||
388
src/transformers/generation_flax_utils.py
Normal file
388
src/transformers/generation_flax_utils.py
Normal file
@@ -0,0 +1,388 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
|
||||||
|
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import flax
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxlib.xla_extension as jax_xla
|
||||||
|
from jax import lax
|
||||||
|
|
||||||
|
from .file_utils import ModelOutput
|
||||||
|
from .generation_flax_logits_process import (
|
||||||
|
FlaxLogitsProcessorList,
|
||||||
|
FlaxTemperatureLogitsWarper,
|
||||||
|
FlaxTopKLogitsWarper,
|
||||||
|
FlaxTopPLogitsWarper,
|
||||||
|
)
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxGreedySearchOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Flax Base class for outputs of decoder-only generation models using greedy search.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||||
|
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
|
||||||
|
padded to :obj:`max_length`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sequences: jax_xla.DeviceArray = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxSampleOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Flax Base class for outputs of decoder-only generation models using sampling.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`):
|
||||||
|
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
|
||||||
|
padded to :obj:`max_length`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sequences: jax_xla.DeviceArray = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class GreedyState:
|
||||||
|
cur_len: jax_xla.DeviceArray
|
||||||
|
sequences: jax_xla.DeviceArray
|
||||||
|
current_token: jax_xla.DeviceArray
|
||||||
|
is_sent_finished: jax_xla.DeviceArray
|
||||||
|
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class SampleState:
|
||||||
|
cur_len: jax_xla.DeviceArray
|
||||||
|
sequences: jax_xla.DeviceArray
|
||||||
|
current_token: jax_xla.DeviceArray
|
||||||
|
is_sent_finished: jax_xla.DeviceArray
|
||||||
|
prng_key: jax_xla.DeviceArray
|
||||||
|
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxGenerationMixin:
|
||||||
|
"""
|
||||||
|
A class containing all of the functions supporting generation, to be used as a mixin in
|
||||||
|
:class:`~transformers.FlaxPreTrainedModel`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _run_loop_in_debug(cond_fn, body_fn, init_state):
|
||||||
|
"""
|
||||||
|
Run generation in untraced mode. This should only be used for debugging purposes.
|
||||||
|
"""
|
||||||
|
state = init_state
|
||||||
|
while cond_fn(state):
|
||||||
|
state = body_fn(state)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
input_ids: jax_xla.DeviceArray,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
pad_token_id: Optional[int] = None,
|
||||||
|
eos_token_id: Optional[int] = None,
|
||||||
|
do_sample: Optional[bool] = None,
|
||||||
|
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
trace: bool = True,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||||
|
and, multinomial sampling.
|
||||||
|
|
||||||
|
Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same
|
||||||
|
name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
|
||||||
|
default values of those config.
|
||||||
|
|
||||||
|
Most of these parameters are explained in more detail in `this blog post
|
||||||
|
<https://huggingface.co/blog/how-to-generate>`__.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
|
||||||
|
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
The sequence used as a prompt for the generation.
|
||||||
|
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||||
|
The maximum length of the sequence to be generated.
|
||||||
|
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to use sampling ; use greedy decoding otherwise.
|
||||||
|
temperature (:obj:`float`, `optional`, defaults to 1.0):
|
||||||
|
The value used to module the next token probabilities.
|
||||||
|
top_k (:obj:`int`, `optional`, defaults to 50):
|
||||||
|
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||||
|
top_p (:obj:`float`, `optional`, defaults to 1.0):
|
||||||
|
If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
|
||||||
|
higher are kept for generation.
|
||||||
|
pad_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the `padding` token.
|
||||||
|
bos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the `beginning-of-sequence` token.
|
||||||
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
|
The id of the `end-of-sequence` token.
|
||||||
|
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
||||||
|
a considerably slower runtime.
|
||||||
|
model_kwargs:
|
||||||
|
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:class:`~transformers.file_utils.ModelOutput`.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||||
|
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
||||||
|
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||||
|
>>> input_context = "The dog"
|
||||||
|
>>> # encode input context
|
||||||
|
>>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
|
||||||
|
>>> # generate candidates using sampling
|
||||||
|
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||||
|
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||||
|
"""
|
||||||
|
# set init values
|
||||||
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
|
|
||||||
|
if do_sample:
|
||||||
|
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
||||||
|
return self._sample(
|
||||||
|
input_ids,
|
||||||
|
max_length,
|
||||||
|
pad_token_id,
|
||||||
|
eos_token_id,
|
||||||
|
prng_key,
|
||||||
|
logits_warper=logits_warper,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
trace=trace,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._greedy_search(
|
||||||
|
input_ids, max_length, pad_token_id, eos_token_id, trace=trace, model_kwargs=model_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_logits_warper(
|
||||||
|
self, top_k: int = None, top_p: float = None, temperature: float = None
|
||||||
|
) -> FlaxLogitsProcessorList:
|
||||||
|
"""
|
||||||
|
This class returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
|
||||||
|
:obj:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# init warp parameters
|
||||||
|
top_k = top_k if top_k is not None else self.config.top_k
|
||||||
|
top_p = top_p if top_p is not None else self.config.top_p
|
||||||
|
temperature = temperature if temperature is not None else self.config.temperature
|
||||||
|
# instantiate warpers list
|
||||||
|
warpers = FlaxLogitsProcessorList()
|
||||||
|
|
||||||
|
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||||
|
# all samplers can be found in `generation_utils_samplers.py`
|
||||||
|
if temperature is not None and temperature != 1.0:
|
||||||
|
warpers.append(FlaxTemperatureLogitsWarper(temperature))
|
||||||
|
if top_k is not None and top_k != 0:
|
||||||
|
warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
|
||||||
|
if top_p is not None and top_p < 1.0:
|
||||||
|
warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
|
||||||
|
|
||||||
|
return warpers
|
||||||
|
|
||||||
|
def _greedy_search(
|
||||||
|
self,
|
||||||
|
input_ids: None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
pad_token_id: Optional[int] = None,
|
||||||
|
eos_token_id: Optional[int] = None,
|
||||||
|
trace: bool = True,
|
||||||
|
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||||
|
):
|
||||||
|
# init values
|
||||||
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
|
||||||
|
batch_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
|
eos_token_id = jnp.array(eos_token_id)
|
||||||
|
pad_token_id = jnp.array(pad_token_id)
|
||||||
|
cur_len = jnp.array(cur_len)
|
||||||
|
|
||||||
|
# per batch-item holding current token in loop.
|
||||||
|
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||||
|
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||||
|
|
||||||
|
# per batch-item state bit indicating if sentence has finished.
|
||||||
|
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||||
|
|
||||||
|
model = self
|
||||||
|
|
||||||
|
# initialize model specific kwargs
|
||||||
|
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = GreedyState(
|
||||||
|
cur_len=cur_len,
|
||||||
|
sequences=sequences,
|
||||||
|
current_token=input_ids,
|
||||||
|
is_sent_finished=is_sent_finished,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def greedy_search_cond_fn(state):
|
||||||
|
"""state termination condition fn."""
|
||||||
|
has_reached_max_length = state.cur_len == max_length
|
||||||
|
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||||
|
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||||
|
return ~finish_generation
|
||||||
|
|
||||||
|
def greedy_search_body_fn(state):
|
||||||
|
"""state update fn."""
|
||||||
|
model_outputs = model(state.current_token, **state.model_kwargs)
|
||||||
|
next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
|
||||||
|
|
||||||
|
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||||
|
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||||
|
next_token = next_token[:, None]
|
||||||
|
|
||||||
|
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||||
|
next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
|
||||||
|
|
||||||
|
return GreedyState(
|
||||||
|
cur_len=state.cur_len + 1,
|
||||||
|
sequences=next_sequences,
|
||||||
|
current_token=next_token,
|
||||||
|
is_sent_finished=next_is_sent_finished,
|
||||||
|
model_kwargs=next_model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||||
|
state = greedy_search_body_fn(state)
|
||||||
|
|
||||||
|
if not trace:
|
||||||
|
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||||
|
else:
|
||||||
|
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||||
|
|
||||||
|
return FlaxGreedySearchOutput(sequences=state.sequences)
|
||||||
|
|
||||||
|
def _sample(
|
||||||
|
self,
|
||||||
|
input_ids: None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
pad_token_id: Optional[int] = None,
|
||||||
|
eos_token_id: Optional[int] = None,
|
||||||
|
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||||
|
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||||
|
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||||
|
trace: bool = True,
|
||||||
|
):
|
||||||
|
# init values
|
||||||
|
max_length = max_length if max_length is not None else self.config.max_length
|
||||||
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
batch_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
|
eos_token_id = jnp.array(eos_token_id)
|
||||||
|
pad_token_id = jnp.array(pad_token_id)
|
||||||
|
cur_len = jnp.array(cur_len)
|
||||||
|
|
||||||
|
# per batch-item holding current token in loop.
|
||||||
|
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||||
|
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||||
|
|
||||||
|
# per batch-item state bit indicating if sentence has finished.
|
||||||
|
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||||
|
|
||||||
|
model = self
|
||||||
|
|
||||||
|
# initialize model specific kwargs
|
||||||
|
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||||
|
|
||||||
|
# initialize state
|
||||||
|
state = SampleState(
|
||||||
|
cur_len=cur_len,
|
||||||
|
sequences=sequences,
|
||||||
|
current_token=input_ids,
|
||||||
|
is_sent_finished=is_sent_finished,
|
||||||
|
prng_key=prng_key,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sample_search_cond_fn(state):
|
||||||
|
"""state termination condition fn."""
|
||||||
|
has_reached_max_length = state.cur_len == max_length
|
||||||
|
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||||
|
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||||
|
return ~finish_generation
|
||||||
|
|
||||||
|
def sample_search_body_fn(state):
|
||||||
|
"""state update fn."""
|
||||||
|
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
||||||
|
model_outputs = model(state.current_token, **state.model_kwargs)
|
||||||
|
|
||||||
|
logits = model_outputs.logits[:, -1]
|
||||||
|
|
||||||
|
# apply top_k, top_k, temperature
|
||||||
|
logits = logits_warper(state.sequences, logits)
|
||||||
|
|
||||||
|
next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
|
||||||
|
|
||||||
|
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||||
|
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||||
|
next_token = next_token[:, None]
|
||||||
|
|
||||||
|
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||||
|
next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
|
||||||
|
|
||||||
|
return SampleState(
|
||||||
|
cur_len=state.cur_len + 1,
|
||||||
|
sequences=next_sequences,
|
||||||
|
current_token=next_token,
|
||||||
|
is_sent_finished=next_is_sent_finished,
|
||||||
|
model_kwargs=next_model_kwargs,
|
||||||
|
prng_key=prng_key_next,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||||
|
state = sample_search_body_fn(state)
|
||||||
|
|
||||||
|
if not trace:
|
||||||
|
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
|
||||||
|
else:
|
||||||
|
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
||||||
|
|
||||||
|
return FlaxSampleOutput(sequences=state.sequences)
|
||||||
@@ -41,6 +41,7 @@ from .file_utils import (
|
|||||||
is_remote_url,
|
is_remote_url,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from .generation_flax_utils import FlaxGenerationMixin
|
||||||
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
@@ -57,7 +58,7 @@ ACT2FN = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FlaxPreTrainedModel(PushToHubMixin):
|
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||||
r"""
|
r"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||||
from flax.linen import combine_masks, dot_product_attention, make_causal_mask
|
from flax.linen import combine_masks, dot_product_attention, make_causal_mask
|
||||||
from flax.traverse_util import flatten_dict
|
|
||||||
from jax import lax
|
from jax import lax
|
||||||
|
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||||
@@ -322,13 +321,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||||
|
|
||||||
@property
|
|
||||||
def _attn_layer_name(self):
|
|
||||||
attn_layer_key_tuple = ("h", "0", "attn")
|
|
||||||
if self.base_model_prefix in set(self.params.keys()):
|
|
||||||
attn_layer_key_tuple = (self.base_model_prefix,) + attn_layer_key_tuple
|
|
||||||
return attn_layer_key_tuple
|
|
||||||
|
|
||||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
@@ -381,28 +373,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
if past_key_values is not None and input_ids.shape[-1] == 1:
|
if past_key_values is not None:
|
||||||
# if `past_key_values` are passed and input_ids are longer than 1, we are in cached auto-regressive generation. It has to be made sure that position_ids are set correctly
|
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
|
||||||
cache_shift = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cache_index",)]
|
|
||||||
position_ids = jnp.broadcast_to(
|
|
||||||
jnp.arange(self.config.max_position_embeddings)[None, :],
|
|
||||||
(batch_size, self.config.max_position_embeddings),
|
|
||||||
)
|
|
||||||
position_ids = lax.dynamic_slice(position_ids, (0, cache_shift), (batch_size, 1))
|
|
||||||
else:
|
|
||||||
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
# if past_key_values are passed we need to create an attention_mask of the same length as `cache_length`
|
attention_mask = jnp.ones((batch_size, sequence_length))
|
||||||
if past_key_values is not None:
|
|
||||||
cache_length = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cached_key",)].shape[
|
|
||||||
1
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
cache_length = sequence_length
|
|
||||||
|
|
||||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. But since GPT2 uses a causal mask, those positions are masked anyways. Thus we can create a single static attention_mask here, which is more efficient for compilation
|
|
||||||
attention_mask = jnp.ones((batch_size, cache_length))
|
|
||||||
|
|
||||||
# Handle any PRNG if needed
|
# Handle any PRNG if needed
|
||||||
rngs = {}
|
rngs = {}
|
||||||
@@ -627,6 +604,32 @@ class FlaxGPT2LMHeadModule(nn.Module):
|
|||||||
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
||||||
module_class = FlaxGPT2LMHeadModule
|
module_class = FlaxGPT2LMHeadModule
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||||
|
# initializing the cache
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
|
||||||
|
past_key_values = self.init_cache(batch_size, max_length)
|
||||||
|
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||||
|
# But since GPT2 uses a causal mask, those positions are masked anyways.
|
||||||
|
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
||||||
|
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||||
|
if attention_mask is not None:
|
||||||
|
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||||
|
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||||
|
else:
|
||||||
|
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"attention_mask": extended_attention_mask,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||||
|
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||||
|
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||||
|
return model_kwargs
|
||||||
|
|
||||||
|
|
||||||
append_call_sample_docstring(
|
append_call_sample_docstring(
|
||||||
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
|
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
|
||||||
|
|||||||
@@ -2,6 +2,36 @@
|
|||||||
from ..file_utils import requires_backends
|
from ..file_utils import requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxLogitsProcessor:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxLogitsProcessorList:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxLogitsWarper:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTemperatureLogitsWarper:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTopKLogitsWarper:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxTopPLogitsWarper:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxPreTrainedModel:
|
class FlaxPreTrainedModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
163
tests/test_generation_flax_logits_process.py
Normal file
163
tests/test_generation_flax_logits_process.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Team Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a clone of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import is_flax_available
|
||||||
|
from transformers.testing_utils import require_flax
|
||||||
|
|
||||||
|
from .test_modeling_flax_common import ids_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from transformers.generation_flax_logits_process import (
|
||||||
|
FlaxLogitsProcessorList,
|
||||||
|
FlaxTemperatureLogitsWarper,
|
||||||
|
FlaxTopKLogitsWarper,
|
||||||
|
FlaxTopPLogitsWarper,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class LogitsProcessorTest(unittest.TestCase):
|
||||||
|
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||||
|
scores = np.ones((batch_size, length)) / length
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def test_temperature_dist_warper(self):
|
||||||
|
input_ids = None
|
||||||
|
length = 20
|
||||||
|
|
||||||
|
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||||
|
|
||||||
|
# tweak scores to not be uniform anymore
|
||||||
|
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
||||||
|
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||||
|
|
||||||
|
# compute softmax
|
||||||
|
probs = jax.nn.softmax(scores, axis=-1)
|
||||||
|
|
||||||
|
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||||
|
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
|
||||||
|
|
||||||
|
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
|
||||||
|
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
|
||||||
|
|
||||||
|
# uniform distribution stays uniform
|
||||||
|
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||||
|
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
|
||||||
|
|
||||||
|
# sharp peaks get higher, valleys get lower
|
||||||
|
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
|
||||||
|
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
|
||||||
|
|
||||||
|
# smooth peaks get lower, valleys get higher
|
||||||
|
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
|
||||||
|
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
|
||||||
|
|
||||||
|
def test_top_k_dist_warper(self):
|
||||||
|
input_ids = None
|
||||||
|
vocab_size = 10
|
||||||
|
batch_size = 2
|
||||||
|
|
||||||
|
# create ramp distribution
|
||||||
|
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
|
||||||
|
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
|
||||||
|
|
||||||
|
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||||
|
|
||||||
|
scores = top_k_warp(input_ids, ramp_logits)
|
||||||
|
|
||||||
|
# check that correct tokens are filtered
|
||||||
|
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
||||||
|
self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
|
||||||
|
|
||||||
|
# check special case
|
||||||
|
length = 5
|
||||||
|
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||||
|
|
||||||
|
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
|
||||||
|
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
||||||
|
|
||||||
|
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||||
|
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
|
||||||
|
|
||||||
|
def test_top_p_dist_warper(self):
|
||||||
|
input_ids = None
|
||||||
|
vocab_size = 10
|
||||||
|
batch_size = 2
|
||||||
|
|
||||||
|
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||||
|
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||||
|
|
||||||
|
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
||||||
|
filtered_dist = np.exp(top_p_warp(input_ids, dist))
|
||||||
|
|
||||||
|
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||||
|
# exp (-inf) => 0
|
||||||
|
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
||||||
|
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||||
|
|
||||||
|
# check edge cases with negative and extreme logits
|
||||||
|
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() - (
|
||||||
|
vocab_size // 2
|
||||||
|
)
|
||||||
|
|
||||||
|
# make ramp_logits more extreme
|
||||||
|
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||||
|
|
||||||
|
# make sure at least 2 tokens are kept
|
||||||
|
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||||
|
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
||||||
|
|
||||||
|
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||||
|
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
|
||||||
|
|
||||||
|
def test_processor_list(self):
|
||||||
|
batch_size = 4
|
||||||
|
sequence_length = 10
|
||||||
|
vocab_size = 15
|
||||||
|
|
||||||
|
# dummy input_ids and scores
|
||||||
|
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||||
|
input_ids_comp = input_ids.copy()
|
||||||
|
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores_comp = scores.copy()
|
||||||
|
|
||||||
|
# instantiate all dist processors
|
||||||
|
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||||
|
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||||
|
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||||
|
|
||||||
|
# no processor list
|
||||||
|
scores = temp_dist_warp(input_ids, scores)
|
||||||
|
scores = top_k_warp(input_ids, scores)
|
||||||
|
scores = top_p_warp(input_ids, scores)
|
||||||
|
|
||||||
|
# with processor list
|
||||||
|
processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
|
||||||
|
scores_comp = processor(input_ids, scores_comp)
|
||||||
|
|
||||||
|
# scores should be equal
|
||||||
|
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
||||||
|
|
||||||
|
# input_ids should never be changed
|
||||||
|
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||||
170
tests/test_generation_flax_utils.py
Normal file
170
tests/test_generation_flax_utils.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from transformers import is_flax_available
|
||||||
|
from transformers.testing_utils import require_flax
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import os
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import jit
|
||||||
|
|
||||||
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
|
|
||||||
|
def ids_tensor(shape, vocab_size, rng=None):
|
||||||
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
|
if rng is None:
|
||||||
|
rng = random.Random()
|
||||||
|
|
||||||
|
total_dims = 1
|
||||||
|
for dim in shape:
|
||||||
|
total_dims *= dim
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for _ in range(total_dims):
|
||||||
|
values.append(rng.randint(0, vocab_size - 1))
|
||||||
|
|
||||||
|
output = np.array(values, dtype=jnp.int32).reshape(shape)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def random_attention_mask(shape, rng=None):
|
||||||
|
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
|
||||||
|
# make sure that at least one token is attended to for each batch
|
||||||
|
attn_mask[:, -1] = 1
|
||||||
|
return attn_mask
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxGenerationTesterMixin:
|
||||||
|
model_tester = None
|
||||||
|
all_generative_model_classes = ()
|
||||||
|
|
||||||
|
def _get_input_ids_and_config(self):
|
||||||
|
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# cut to half length & take max batch_size 3
|
||||||
|
max_batch_size = 2
|
||||||
|
sequence_length = inputs["input_ids"].shape[-1] // 2
|
||||||
|
input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
|
||||||
|
|
||||||
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
|
attention_mask = attention_mask[:max_batch_size, :sequence_length]
|
||||||
|
|
||||||
|
# generate max 5 tokens
|
||||||
|
max_length = input_ids.shape[-1] + 5
|
||||||
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
|
config.pad_token_id = config.eos_token_id
|
||||||
|
return config, input_ids, attention_mask, max_length
|
||||||
|
|
||||||
|
def test_greedy_generate(self):
|
||||||
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
|
config.do_sample = False
|
||||||
|
config.max_length = max_length
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||||
|
|
||||||
|
jit_generate = jit(model.generate)
|
||||||
|
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||||
|
|
||||||
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_sample_generate(self):
|
||||||
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
|
config.do_sample = True
|
||||||
|
config.max_length = max_length
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||||
|
|
||||||
|
jit_generate = jit(model.generate)
|
||||||
|
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||||
|
|
||||||
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_sample_generate_logits_warper(self):
|
||||||
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
|
config.do_sample = True
|
||||||
|
config.max_length = max_length
|
||||||
|
config.temperature = 0.8
|
||||||
|
config.top_k = 10
|
||||||
|
config.top_p = 0.3
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||||
|
|
||||||
|
jit_generate = jit(model.generate)
|
||||||
|
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||||
|
|
||||||
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_greedy_generate_attn_mask(self):
|
||||||
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
|
# pad attention mask on the left
|
||||||
|
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||||
|
|
||||||
|
config.do_sample = False
|
||||||
|
config.max_length = max_length
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||||
|
|
||||||
|
jit_generate = jit(model.generate)
|
||||||
|
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||||
|
|
||||||
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_sample_generate_attn_mask(self):
|
||||||
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
|
# pad attention mask on the left
|
||||||
|
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||||
|
|
||||||
|
config.do_sample = True
|
||||||
|
config.max_length = max_length
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||||
|
|
||||||
|
jit_generate = jit(model.generate)
|
||||||
|
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||||
|
|
||||||
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
@@ -19,16 +19,16 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import GPT2Config, is_flax_available, is_torch_available
|
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
|
||||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||||
|
|
||||||
|
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax import lax
|
|
||||||
from transformers.modeling_flax_pytorch_utils import (
|
from transformers.modeling_flax_pytorch_utils import (
|
||||||
convert_pytorch_state_dict_to_flax,
|
convert_pytorch_state_dict_to_flax,
|
||||||
load_flax_weights_in_pytorch_model,
|
load_flax_weights_in_pytorch_model,
|
||||||
@@ -116,8 +116,25 @@ class FlaxGPT2ModelTester:
|
|||||||
model = model_class_name(config)
|
model = model_class_name(config)
|
||||||
|
|
||||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||||
outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values)
|
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
|
||||||
outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values)
|
|
||||||
|
position_ids = jnp.broadcast_to(
|
||||||
|
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||||
|
)
|
||||||
|
outputs_cache = model(
|
||||||
|
input_ids[:, :-1],
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||||
|
outputs_cache_next = model(
|
||||||
|
input_ids[:, -1:],
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=outputs_cache.past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
|
||||||
outputs = model(input_ids)
|
outputs = model(input_ids)
|
||||||
|
|
||||||
@@ -134,10 +151,22 @@ class FlaxGPT2ModelTester:
|
|||||||
)
|
)
|
||||||
|
|
||||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||||
|
position_ids = jnp.broadcast_to(
|
||||||
|
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||||
|
)
|
||||||
|
|
||||||
outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values)
|
outputs_cache = model(
|
||||||
|
input_ids[:, :-1],
|
||||||
|
attention_mask=attention_mask_cache,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
position_ids=position_ids,
|
||||||
|
)
|
||||||
|
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||||
outputs_cache_next = model(
|
outputs_cache_next = model(
|
||||||
input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache
|
input_ids[:, -1:],
|
||||||
|
past_key_values=outputs_cache.past_key_values,
|
||||||
|
attention_mask=attention_mask_cache,
|
||||||
|
position_ids=position_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs = model(input_ids, attention_mask=attention_mask)
|
outputs = model(input_ids, attention_mask=attention_mask)
|
||||||
@@ -145,66 +174,12 @@ class FlaxGPT2ModelTester:
|
|||||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||||
|
|
||||||
def check_use_cache_generation(self, config, input_ids):
|
|
||||||
prompt_length = 3
|
|
||||||
model = FlaxGPT2LMHeadModel(config)
|
|
||||||
max_length = 10
|
|
||||||
batch_size = 1
|
|
||||||
|
|
||||||
prompt_ids = input_ids[:1, :prompt_length]
|
|
||||||
|
|
||||||
# put all generation logic into one function
|
|
||||||
def generate(prompt_ids):
|
|
||||||
def first_pass(prompt_ids):
|
|
||||||
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
|
|
||||||
next_token = jnp.argmax(logits[:, -1:], axis=-1)
|
|
||||||
return next_token, cache
|
|
||||||
|
|
||||||
def greedy_search_cond_fn(state):
|
|
||||||
cur_len, _, _, _ = state
|
|
||||||
return ~(cur_len == max_length - 1)
|
|
||||||
|
|
||||||
def greedy_search_body_fn(state):
|
|
||||||
cur_len, sequences, current_token, cache = state
|
|
||||||
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
|
|
||||||
|
|
||||||
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
|
|
||||||
next_token = jnp.argmax(next_logits, axis=-1)
|
|
||||||
|
|
||||||
return cur_len + 1, next_sequences, next_token, next_cache
|
|
||||||
|
|
||||||
# init tensor to be filled with generation result
|
|
||||||
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
|
|
||||||
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
|
|
||||||
|
|
||||||
# init past key values for cache
|
|
||||||
past_key_values = model.init_cache(batch_size, max_length)
|
|
||||||
|
|
||||||
# first pass with long prompt
|
|
||||||
next_token, cache = first_pass(prompt_ids)
|
|
||||||
|
|
||||||
# prepare state for generation loop
|
|
||||||
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
|
|
||||||
|
|
||||||
# fast generation
|
|
||||||
_, output_sequences, final_token, _ = lax.while_loop(
|
|
||||||
greedy_search_cond_fn, greedy_search_body_fn, init_state
|
|
||||||
)
|
|
||||||
|
|
||||||
# append last token
|
|
||||||
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
|
|
||||||
|
|
||||||
return output_sequences
|
|
||||||
|
|
||||||
jit_generate = jax.jit(generate)
|
|
||||||
output_sequences = jit_generate(prompt_ids)
|
|
||||||
self.parent.assertEqual(output_sequences.shape, (1, max_length))
|
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||||
|
|
||||||
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
|
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
|
||||||
|
all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else ()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = FlaxGPT2ModelTester(self)
|
self.model_tester = FlaxGPT2ModelTester(self)
|
||||||
@@ -221,9 +196,27 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
model_class_name, config, input_ids, attention_mask
|
model_class_name, config, input_ids, attention_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_use_cache_generation(self):
|
@slow
|
||||||
config, input_ids, _ = self.model_tester.prepare_config_and_inputs()
|
def test_batch_generation(self):
|
||||||
self.model_tester.check_use_cache_generation(config, input_ids)
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||||
|
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
||||||
|
|
||||||
|
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
|
model.do_sample = False
|
||||||
|
model.config.pad_token_id = model.config.eos_token_id
|
||||||
|
|
||||||
|
jit_generate = jax.jit(model.generate)
|
||||||
|
|
||||||
|
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
|
||||||
|
|
||||||
|
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
|
||||||
|
|
||||||
|
expected_string = [
|
||||||
|
"Hello this is a long string of words. I'm going to try to explain what I mean.",
|
||||||
|
"Hey, I'm not sure if I'm going to be able to do",
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(output_string, expected_string)
|
||||||
|
|
||||||
# overwrite from common since `attention_mask` in combination
|
# overwrite from common since `attention_mask` in combination
|
||||||
# with `causal_mask` behaves slighly differently
|
# with `causal_mask` behaves slighly differently
|
||||||
|
|||||||
Reference in New Issue
Block a user