Flax Generate (#11777)
* fix_torch_device_generate_test * remove @ * add * indexing * correct a couple of tests * fix tests * add logits processor * finish top_k, top_p, temp * add docs * correct flax prng key default * improve generate * add generation docs * add docs * make style * revert model outputs change * make style * correct typo * fix tests * fix slow test * add raise * finish generation Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
2df546918e
commit
996a315e76
@@ -78,6 +78,9 @@ GreedySearchOutput
|
||||
.. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_flax_utils.FlaxGreedySearchOutput
|
||||
:members:
|
||||
|
||||
|
||||
SampleOutput
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
@@ -88,6 +91,9 @@ SampleOutput
|
||||
.. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_flax_utils.FlaxSampleOutput
|
||||
:members:
|
||||
|
||||
|
||||
BeamSearchOutput
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
@@ -160,6 +166,24 @@ generation.
|
||||
.. autoclass:: transformers.InfNanRemoveLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.FlaxLogitsProcessor
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.FlaxLogitsProcessorList
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.FlaxLogitsWarper
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.FlaxTemperatureLogitsWarper
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.FlaxTopPLogitsWarper
|
||||
:members: __call__
|
||||
|
||||
.. autoclass:: transformers.FlaxTopKLogitsWarper
|
||||
:members: __call__
|
||||
|
||||
|
||||
StoppingCriteria
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -26,8 +26,9 @@ are common among all the models to:
|
||||
|
||||
The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin`
|
||||
(for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models) or
|
||||
for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models) and
|
||||
:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models)
|
||||
for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models),
|
||||
:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models) and
|
||||
:class:`~transformers.generation_flax_utils.FlaxGenerationMixin` (for the Flax/JAX models).
|
||||
|
||||
|
||||
PreTrainedModel
|
||||
@@ -74,6 +75,9 @@ Generation
|
||||
.. autoclass:: transformers.generation_tf_utils.TFGenerationMixin
|
||||
:members:
|
||||
|
||||
.. autoclass:: transformers.generation_flax_utils.FlaxGenerationMixin
|
||||
:members:
|
||||
|
||||
|
||||
Pushing to the Hub
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -1437,6 +1437,14 @@ else:
|
||||
|
||||
# FLAX-backed objects
|
||||
if is_flax_available():
|
||||
_import_structure["generation_flax_logits_process"] = [
|
||||
"FlaxLogitsProcessor",
|
||||
"FlaxLogitsProcessorList",
|
||||
"FlaxLogitsWarper",
|
||||
"FlaxTemperatureLogitsWarper",
|
||||
"FlaxTopKLogitsWarper",
|
||||
"FlaxTopPLogitsWarper",
|
||||
]
|
||||
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
|
||||
_import_structure["models.auto"].extend(
|
||||
[
|
||||
@@ -2693,6 +2701,14 @@ if TYPE_CHECKING:
|
||||
from .utils.dummy_tf_objects import *
|
||||
|
||||
if is_flax_available():
|
||||
from .generation_flax_logits_process import (
|
||||
FlaxLogitsProcessor,
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxLogitsWarper,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
from .modeling_flax_utils import FlaxPreTrainedModel
|
||||
from .models.auto import (
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
|
||||
192
src/transformers/generation_flax_logits_process.py
Normal file
192
src/transformers/generation_flax_logits_process.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .utils.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
|
||||
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`):
|
||||
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
|
||||
search or log softmax for each vocabulary token when using beam search
|
||||
kwargs:
|
||||
Additional logits processor specific kwargs.
|
||||
|
||||
Return:
|
||||
:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class FlaxLogitsProcessor(ABC):
|
||||
"""Abstract base class for all logit processors that can be applied during generation."""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||
"""Flax method for processing logits."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
|
||||
class FlaxLogitsWarper(ABC):
|
||||
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||
"""Flax method for warping logits."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
|
||||
)
|
||||
|
||||
|
||||
class FlaxLogitsProcessorList(list):
|
||||
"""
|
||||
This class can be used to create a list of :class:`~transformers.FlaxLogitsProcessor` or
|
||||
:class:`~transformers.FlaxLogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits
|
||||
from list and adds a specific `__call__` method to apply each :class:`~transformers.FlaxLogitsProcessor` or
|
||||
:class:`~transformers.FlaxLogitsWarper` to the inputs.
|
||||
"""
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
|
||||
for processor in self:
|
||||
function_args = inspect.signature(processor.__call__).parameters
|
||||
if len(function_args) > 2:
|
||||
assert all(
|
||||
arg in kwargs for arg in list(function_args.keys())[2:]
|
||||
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
|
||||
scores = processor(input_ids, scores, **kwargs)
|
||||
else:
|
||||
scores = processor(input_ids, scores)
|
||||
return scores
|
||||
|
||||
|
||||
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
||||
r"""
|
||||
:class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution).
|
||||
|
||||
Args:
|
||||
temperature (:obj:`float`):
|
||||
The value used to module the logits distribution.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
|
||||
|
||||
self.temperature = temperature
|
||||
|
||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
|
||||
class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
||||
"""
|
||||
:class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <=
|
||||
prob_cut_off.
|
||||
|
||||
Args:
|
||||
top_p (:obj:`float`):
|
||||
If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are
|
||||
kept for generation.
|
||||
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
|
||||
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
|
||||
|
||||
self.top_p = top_p
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
||||
|
||||
mask_scores = jnp.full_like(scores, self.filter_value)
|
||||
cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
|
||||
score_mask = cumulative_probs < self.top_p
|
||||
|
||||
# include the token that is higher than top_p as well
|
||||
score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True)
|
||||
|
||||
# min tokens to keep
|
||||
score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True)
|
||||
|
||||
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
||||
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
||||
|
||||
return next_scores
|
||||
|
||||
|
||||
class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
||||
r"""
|
||||
:class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements.
|
||||
|
||||
Args:
|
||||
top_k (:obj:`int`):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`):
|
||||
All filtered values will be set to this float value.
|
||||
min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1):
|
||||
Minimum number of tokens that cannot be filtered.
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
|
||||
|
||||
self.top_k = top_k
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
||||
batch_size, vocab_size = scores.shape
|
||||
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
||||
|
||||
topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check
|
||||
topk_scores, topk_indices = lax.top_k(scores, topk)
|
||||
shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
|
||||
topk_scores_flat = topk_scores.flatten()
|
||||
topk_indices_flat = topk_indices.flatten() + shift
|
||||
|
||||
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
|
||||
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
||||
return next_scores
|
||||
388
src/transformers/generation_flax_utils.py
Normal file
388
src/transformers/generation_flax_utils.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxlib.xla_extension as jax_xla
|
||||
from jax import lax
|
||||
|
||||
from .file_utils import ModelOutput
|
||||
from .generation_flax_logits_process import (
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxGreedySearchOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using greedy search.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
||||
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
|
||||
padded to :obj:`max_length`.
|
||||
"""
|
||||
|
||||
sequences: jax_xla.DeviceArray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class FlaxSampleOutput(ModelOutput):
|
||||
"""
|
||||
Flax Base class for outputs of decoder-only generation models using sampling.
|
||||
|
||||
|
||||
Args:
|
||||
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`):
|
||||
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
|
||||
padded to :obj:`max_length`.
|
||||
"""
|
||||
|
||||
sequences: jax_xla.DeviceArray = None
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class GreedyState:
|
||||
cur_len: jax_xla.DeviceArray
|
||||
sequences: jax_xla.DeviceArray
|
||||
current_token: jax_xla.DeviceArray
|
||||
is_sent_finished: jax_xla.DeviceArray
|
||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class SampleState:
|
||||
cur_len: jax_xla.DeviceArray
|
||||
sequences: jax_xla.DeviceArray
|
||||
current_token: jax_xla.DeviceArray
|
||||
is_sent_finished: jax_xla.DeviceArray
|
||||
prng_key: jax_xla.DeviceArray
|
||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||
|
||||
|
||||
class FlaxGenerationMixin:
|
||||
"""
|
||||
A class containing all of the functions supporting generation, to be used as a mixin in
|
||||
:class:`~transformers.FlaxPreTrainedModel`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _run_loop_in_debug(cond_fn, body_fn, init_state):
|
||||
"""
|
||||
Run generation in untraced mode. This should only be used for debugging purposes.
|
||||
"""
|
||||
state = init_state
|
||||
while cond_fn(state):
|
||||
state = body_fn(state)
|
||||
return state
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: jax_xla.DeviceArray,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
do_sample: Optional[bool] = None,
|
||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
trace: bool = True,
|
||||
**model_kwargs,
|
||||
):
|
||||
r"""
|
||||
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
|
||||
and, multinomial sampling.
|
||||
|
||||
Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same
|
||||
name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the
|
||||
default values of those config.
|
||||
|
||||
Most of these parameters are explained in more detail in `this blog post
|
||||
<https://huggingface.co/blog/how-to-generate>`__.
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
The sequence used as a prompt for the generation.
|
||||
max_length (:obj:`int`, `optional`, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to use sampling ; use greedy decoding otherwise.
|
||||
temperature (:obj:`float`, `optional`, defaults to 1.0):
|
||||
The value used to module the next token probabilities.
|
||||
top_k (:obj:`int`, `optional`, defaults to 50):
|
||||
The number of highest probability vocabulary tokens to keep for top-k-filtering.
|
||||
top_p (:obj:`float`, `optional`, defaults to 1.0):
|
||||
If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or
|
||||
higher are kept for generation.
|
||||
pad_token_id (:obj:`int`, `optional`):
|
||||
The id of the `padding` token.
|
||||
bos_token_id (:obj:`int`, `optional`):
|
||||
The id of the `beginning-of-sequence` token.
|
||||
eos_token_id (:obj:`int`, `optional`):
|
||||
The id of the `end-of-sequence` token.
|
||||
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
|
||||
a considerably slower runtime.
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
|
||||
|
||||
Return:
|
||||
:class:`~transformers.file_utils.ModelOutput`.
|
||||
|
||||
Examples::
|
||||
>>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||
>>> input_context = "The dog"
|
||||
>>> # encode input context
|
||||
>>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
|
||||
>>> # generate candidates using sampling
|
||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
||||
"""
|
||||
# set init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||
|
||||
if do_sample:
|
||||
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
return self._sample(
|
||||
input_ids,
|
||||
max_length,
|
||||
pad_token_id,
|
||||
eos_token_id,
|
||||
prng_key,
|
||||
logits_warper=logits_warper,
|
||||
model_kwargs=model_kwargs,
|
||||
trace=trace,
|
||||
)
|
||||
else:
|
||||
return self._greedy_search(
|
||||
input_ids, max_length, pad_token_id, eos_token_id, trace=trace, model_kwargs=model_kwargs
|
||||
)
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: int = None, top_p: float = None, temperature: float = None
|
||||
) -> FlaxLogitsProcessorList:
|
||||
"""
|
||||
This class returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
|
||||
:obj:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling.
|
||||
"""
|
||||
|
||||
# init warp parameters
|
||||
top_k = top_k if top_k is not None else self.config.top_k
|
||||
top_p = top_p if top_p is not None else self.config.top_p
|
||||
temperature = temperature if temperature is not None else self.config.temperature
|
||||
# instantiate warpers list
|
||||
warpers = FlaxLogitsProcessorList()
|
||||
|
||||
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||
# all samplers can be found in `generation_utils_samplers.py`
|
||||
if temperature is not None and temperature != 1.0:
|
||||
warpers.append(FlaxTemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
|
||||
|
||||
return warpers
|
||||
|
||||
def _greedy_search(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
trace: bool = True,
|
||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
):
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||
|
||||
# per batch-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||
|
||||
model = self
|
||||
|
||||
# initialize model specific kwargs
|
||||
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||
|
||||
# initialize state
|
||||
state = GreedyState(
|
||||
cur_len=cur_len,
|
||||
sequences=sequences,
|
||||
current_token=input_ids,
|
||||
is_sent_finished=is_sent_finished,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def greedy_search_cond_fn(state):
|
||||
"""state termination condition fn."""
|
||||
has_reached_max_length = state.cur_len == max_length
|
||||
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||
return ~finish_generation
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
model_outputs = model(state.current_token, **state.model_kwargs)
|
||||
next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
|
||||
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||
next_token = next_token[:, None]
|
||||
|
||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||
next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
|
||||
|
||||
return GreedyState(
|
||||
cur_len=state.cur_len + 1,
|
||||
sequences=next_sequences,
|
||||
current_token=next_token,
|
||||
is_sent_finished=next_is_sent_finished,
|
||||
model_kwargs=next_model_kwargs,
|
||||
)
|
||||
|
||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||
state = greedy_search_body_fn(state)
|
||||
|
||||
if not trace:
|
||||
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||
else:
|
||||
state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state)
|
||||
|
||||
return FlaxGreedySearchOutput(sequences=state.sequences)
|
||||
|
||||
def _sample(
|
||||
self,
|
||||
input_ids: None,
|
||||
max_length: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||
trace: bool = True,
|
||||
):
|
||||
# init values
|
||||
max_length = max_length if max_length is not None else self.config.max_length
|
||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
||||
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
||||
|
||||
# per batch-item state bit indicating if sentence has finished.
|
||||
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
||||
|
||||
model = self
|
||||
|
||||
# initialize model specific kwargs
|
||||
model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
||||
|
||||
# initialize state
|
||||
state = SampleState(
|
||||
cur_len=cur_len,
|
||||
sequences=sequences,
|
||||
current_token=input_ids,
|
||||
is_sent_finished=is_sent_finished,
|
||||
prng_key=prng_key,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
def sample_search_cond_fn(state):
|
||||
"""state termination condition fn."""
|
||||
has_reached_max_length = state.cur_len == max_length
|
||||
all_sequence_finished = jnp.all(state.is_sent_finished)
|
||||
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
||||
return ~finish_generation
|
||||
|
||||
def sample_search_body_fn(state):
|
||||
"""state update fn."""
|
||||
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
||||
model_outputs = model(state.current_token, **state.model_kwargs)
|
||||
|
||||
logits = model_outputs.logits[:, -1]
|
||||
|
||||
# apply top_k, top_k, temperature
|
||||
logits = logits_warper(state.sequences, logits)
|
||||
|
||||
next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
|
||||
|
||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||
next_token = next_token[:, None]
|
||||
|
||||
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
||||
next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
|
||||
|
||||
return SampleState(
|
||||
cur_len=state.cur_len + 1,
|
||||
sequences=next_sequences,
|
||||
current_token=next_token,
|
||||
is_sent_finished=next_is_sent_finished,
|
||||
model_kwargs=next_model_kwargs,
|
||||
prng_key=prng_key_next,
|
||||
)
|
||||
|
||||
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
||||
state = sample_search_body_fn(state)
|
||||
|
||||
if not trace:
|
||||
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
|
||||
else:
|
||||
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
||||
|
||||
return FlaxSampleOutput(sequences=state.sequences)
|
||||
@@ -41,6 +41,7 @@ from .file_utils import (
|
||||
is_remote_url,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .generation_flax_utils import FlaxGenerationMixin
|
||||
from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
|
||||
from .utils import logging
|
||||
|
||||
@@ -57,7 +58,7 @@ ACT2FN = {
|
||||
}
|
||||
|
||||
|
||||
class FlaxPreTrainedModel(PushToHubMixin):
|
||||
class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, unfreeze
|
||||
from flax.linen import combine_masks, dot_product_attention, make_causal_mask
|
||||
from flax.traverse_util import flatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
@@ -322,13 +321,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
|
||||
|
||||
@property
|
||||
def _attn_layer_name(self):
|
||||
attn_layer_key_tuple = ("h", "0", "attn")
|
||||
if self.base_model_prefix in set(self.params.keys()):
|
||||
attn_layer_key_tuple = (self.base_model_prefix,) + attn_layer_key_tuple
|
||||
return attn_layer_key_tuple
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
@@ -381,28 +373,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
|
||||
if position_ids is None:
|
||||
if past_key_values is not None and input_ids.shape[-1] == 1:
|
||||
# if `past_key_values` are passed and input_ids are longer than 1, we are in cached auto-regressive generation. It has to be made sure that position_ids are set correctly
|
||||
cache_shift = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cache_index",)]
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(self.config.max_position_embeddings)[None, :],
|
||||
(batch_size, self.config.max_position_embeddings),
|
||||
)
|
||||
position_ids = lax.dynamic_slice(position_ids, (0, cache_shift), (batch_size, 1))
|
||||
else:
|
||||
if past_key_values is not None:
|
||||
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
|
||||
|
||||
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
||||
|
||||
if attention_mask is None:
|
||||
# if past_key_values are passed we need to create an attention_mask of the same length as `cache_length`
|
||||
if past_key_values is not None:
|
||||
cache_length = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cached_key",)].shape[
|
||||
1
|
||||
]
|
||||
else:
|
||||
cache_length = sequence_length
|
||||
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. But since GPT2 uses a causal mask, those positions are masked anyways. Thus we can create a single static attention_mask here, which is more efficient for compilation
|
||||
attention_mask = jnp.ones((batch_size, cache_length))
|
||||
attention_mask = jnp.ones((batch_size, sequence_length))
|
||||
|
||||
# Handle any PRNG if needed
|
||||
rngs = {}
|
||||
@@ -627,6 +604,32 @@ class FlaxGPT2LMHeadModule(nn.Module):
|
||||
class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
|
||||
module_class = FlaxGPT2LMHeadModule
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
|
||||
# initializing the cache
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
past_key_values = self.init_cache(batch_size, max_length)
|
||||
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
||||
# But since GPT2 uses a causal mask, those positions are masked anyways.
|
||||
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
||||
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
if attention_mask is not None:
|
||||
position_ids = attention_mask.cumsum(axis=-1) - 1
|
||||
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
|
||||
else:
|
||||
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
||||
|
||||
return {
|
||||
"past_key_values": past_key_values,
|
||||
"attention_mask": extended_attention_mask,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
|
||||
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
||||
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
||||
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
|
||||
return model_kwargs
|
||||
|
||||
|
||||
append_call_sample_docstring(
|
||||
FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC
|
||||
|
||||
@@ -2,6 +2,36 @@
|
||||
from ..file_utils import requires_backends
|
||||
|
||||
|
||||
class FlaxLogitsProcessor:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxLogitsProcessorList:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxLogitsWarper:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxTemperatureLogitsWarper:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxTopKLogitsWarper:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxTopPLogitsWarper:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
|
||||
class FlaxPreTrainedModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["flax"])
|
||||
|
||||
163
tests/test_generation_flax_logits_process.py
Normal file
163
tests/test_generation_flax_logits_process.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a clone of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
|
||||
from .test_modeling_flax_common import ids_tensor
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from transformers.generation_flax_logits_process import (
|
||||
FlaxLogitsProcessorList,
|
||||
FlaxTemperatureLogitsWarper,
|
||||
FlaxTopKLogitsWarper,
|
||||
FlaxTopPLogitsWarper,
|
||||
)
|
||||
|
||||
|
||||
@require_flax
|
||||
class LogitsProcessorTest(unittest.TestCase):
|
||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||
scores = np.ones((batch_size, length)) / length
|
||||
return scores
|
||||
|
||||
def test_temperature_dist_warper(self):
|
||||
input_ids = None
|
||||
length = 20
|
||||
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
|
||||
# tweak scores to not be uniform anymore
|
||||
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||
|
||||
# compute softmax
|
||||
probs = jax.nn.softmax(scores, axis=-1)
|
||||
|
||||
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
|
||||
|
||||
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
|
||||
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
|
||||
|
||||
# uniform distribution stays uniform
|
||||
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3))
|
||||
|
||||
# sharp peaks get higher, valleys get lower
|
||||
self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max())
|
||||
self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min())
|
||||
|
||||
# smooth peaks get lower, valleys get higher
|
||||
self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max())
|
||||
self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min())
|
||||
|
||||
def test_top_k_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create ramp distribution
|
||||
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy()
|
||||
ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size
|
||||
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
|
||||
scores = top_k_warp(input_ids, ramp_logits)
|
||||
|
||||
# check that correct tokens are filtered
|
||||
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
||||
self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True])
|
||||
|
||||
# check special case
|
||||
length = 5
|
||||
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||
|
||||
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
|
||||
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
||||
|
||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
|
||||
|
||||
def test_top_p_dist_warper(self):
|
||||
input_ids = None
|
||||
vocab_size = 10
|
||||
batch_size = 2
|
||||
|
||||
# create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper)
|
||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
||||
filtered_dist = np.exp(top_p_warp(input_ids, dist))
|
||||
|
||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||
# exp (-inf) => 0
|
||||
EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
|
||||
self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3))
|
||||
|
||||
# check edge cases with negative and extreme logits
|
||||
ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() - (
|
||||
vocab_size // 2
|
||||
)
|
||||
|
||||
# make ramp_logits more extreme
|
||||
ramp_logits[1] = ramp_logits[1] * 100.0
|
||||
|
||||
# make sure at least 2 tokens are kept
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
||||
|
||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
|
||||
|
||||
def test_processor_list(self):
|
||||
batch_size = 4
|
||||
sequence_length = 10
|
||||
vocab_size = 15
|
||||
|
||||
# dummy input_ids and scores
|
||||
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||
input_ids_comp = input_ids.copy()
|
||||
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores_comp = scores.copy()
|
||||
|
||||
# instantiate all dist processors
|
||||
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||
|
||||
# no processor list
|
||||
scores = temp_dist_warp(input_ids, scores)
|
||||
scores = top_k_warp(input_ids, scores)
|
||||
scores = top_p_warp(input_ids, scores)
|
||||
|
||||
# with processor list
|
||||
processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
|
||||
scores_comp = processor(input_ids, scores_comp)
|
||||
|
||||
# scores should be equal
|
||||
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
||||
|
||||
# input_ids should never be changed
|
||||
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||
170
tests/test_generation_flax_utils.py
Normal file
170
tests/test_generation_flax_utils.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import is_flax_available
|
||||
from transformers.testing_utils import require_flax
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import os
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
|
||||
def ids_tensor(shape, vocab_size, rng=None):
|
||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||
if rng is None:
|
||||
rng = random.Random()
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.randint(0, vocab_size - 1))
|
||||
|
||||
output = np.array(values, dtype=jnp.int32).reshape(shape)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def random_attention_mask(shape, rng=None):
|
||||
attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
|
||||
# make sure that at least one token is attended to for each batch
|
||||
attn_mask[:, -1] = 1
|
||||
return attn_mask
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGenerationTesterMixin:
|
||||
model_tester = None
|
||||
all_generative_model_classes = ()
|
||||
|
||||
def _get_input_ids_and_config(self):
|
||||
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
max_batch_size = 2
|
||||
sequence_length = inputs["input_ids"].shape[-1] // 2
|
||||
input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
|
||||
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
attention_mask = attention_mask[:max_batch_size, :sequence_length]
|
||||
|
||||
# generate max 5 tokens
|
||||
max_length = input_ids.shape[-1] + 5
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
return config, input_ids, attention_mask, max_length
|
||||
|
||||
def test_greedy_generate(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_sample_generate(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_sample_generate_logits_warper(self):
|
||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
config.temperature = 0.8
|
||||
config.top_k = 10
|
||||
config.top_p = 0.3
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_greedy_generate_attn_mask(self):
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
def test_sample_generate_attn_mask(self):
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
|
||||
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||
|
||||
jit_generate = jit(model.generate)
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
@@ -19,16 +19,16 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from transformers import GPT2Config, is_flax_available, is_torch_available
|
||||
from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
from .test_generation_flax_utils import FlaxGenerationTesterMixin
|
||||
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@@ -116,8 +116,25 @@ class FlaxGPT2ModelTester:
|
||||
model = model_class_name(config)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values)
|
||||
outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values)
|
||||
attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4")
|
||||
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||
)
|
||||
outputs_cache = model(
|
||||
input_ids[:, :-1],
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:],
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=outputs_cache.past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
outputs = model(input_ids)
|
||||
|
||||
@@ -134,10 +151,22 @@ class FlaxGPT2ModelTester:
|
||||
)
|
||||
|
||||
past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
|
||||
position_ids = jnp.broadcast_to(
|
||||
jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
|
||||
)
|
||||
|
||||
outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values)
|
||||
outputs_cache = model(
|
||||
input_ids[:, :-1],
|
||||
attention_mask=attention_mask_cache,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4")
|
||||
outputs_cache_next = model(
|
||||
input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache
|
||||
input_ids[:, -1:],
|
||||
past_key_values=outputs_cache.past_key_values,
|
||||
attention_mask=attention_mask_cache,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
outputs = model(input_ids, attention_mask=attention_mask)
|
||||
@@ -145,66 +174,12 @@ class FlaxGPT2ModelTester:
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_use_cache_generation(self, config, input_ids):
|
||||
prompt_length = 3
|
||||
model = FlaxGPT2LMHeadModel(config)
|
||||
max_length = 10
|
||||
batch_size = 1
|
||||
|
||||
prompt_ids = input_ids[:1, :prompt_length]
|
||||
|
||||
# put all generation logic into one function
|
||||
def generate(prompt_ids):
|
||||
def first_pass(prompt_ids):
|
||||
logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2]
|
||||
next_token = jnp.argmax(logits[:, -1:], axis=-1)
|
||||
return next_token, cache
|
||||
|
||||
def greedy_search_cond_fn(state):
|
||||
cur_len, _, _, _ = state
|
||||
return ~(cur_len == max_length - 1)
|
||||
|
||||
def greedy_search_body_fn(state):
|
||||
cur_len, sequences, current_token, cache = state
|
||||
next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len))
|
||||
|
||||
next_logits, next_cache = model(current_token, past_key_values=cache)[:2]
|
||||
next_token = jnp.argmax(next_logits, axis=-1)
|
||||
|
||||
return cur_len + 1, next_sequences, next_token, next_cache
|
||||
|
||||
# init tensor to be filled with generation result
|
||||
init_sequences = jnp.zeros((batch_size, max_length), dtype="i4")
|
||||
init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0))
|
||||
|
||||
# init past key values for cache
|
||||
past_key_values = model.init_cache(batch_size, max_length)
|
||||
|
||||
# first pass with long prompt
|
||||
next_token, cache = first_pass(prompt_ids)
|
||||
|
||||
# prepare state for generation loop
|
||||
init_state = (jnp.array(prompt_length), init_sequences, next_token, cache)
|
||||
|
||||
# fast generation
|
||||
_, output_sequences, final_token, _ = lax.while_loop(
|
||||
greedy_search_cond_fn, greedy_search_body_fn, init_state
|
||||
)
|
||||
|
||||
# append last token
|
||||
output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1))
|
||||
|
||||
return output_sequences
|
||||
|
||||
jit_generate = jax.jit(generate)
|
||||
output_sequences = jit_generate(prompt_ids)
|
||||
self.parent.assertEqual(output_sequences.shape, (1, max_length))
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else ()
|
||||
all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxGPT2ModelTester(self)
|
||||
@@ -221,9 +196,27 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
model_class_name, config, input_ids, attention_mask
|
||||
)
|
||||
|
||||
def test_use_cache_generation(self):
|
||||
config, input_ids, _ = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_use_cache_generation(config, input_ids)
|
||||
@slow
|
||||
def test_batch_generation(self):
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
|
||||
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
|
||||
|
||||
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
model.do_sample = False
|
||||
model.config.pad_token_id = model.config.eos_token_id
|
||||
|
||||
jit_generate = jax.jit(model.generate)
|
||||
|
||||
output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences
|
||||
|
||||
output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
|
||||
|
||||
expected_string = [
|
||||
"Hello this is a long string of words. I'm going to try to explain what I mean.",
|
||||
"Hey, I'm not sure if I'm going to be able to do",
|
||||
]
|
||||
|
||||
self.assertListEqual(output_string, expected_string)
|
||||
|
||||
# overwrite from common since `attention_mask` in combination
|
||||
# with `causal_mask` behaves slighly differently
|
||||
|
||||
Reference in New Issue
Block a user