From 996a315e76f6c972c854990e6114226a91bc0a90 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 27 May 2021 00:18:17 +0100 Subject: [PATCH] Flax Generate (#11777) * fix_torch_device_generate_test * remove @ * add * indexing * correct a couple of tests * fix tests * add logits processor * finish top_k, top_p, temp * add docs * correct flax prng key default * improve generate * add generation docs * add docs * make style * revert model outputs change * make style * correct typo * fix tests * fix slow test * add raise * finish generation Co-authored-by: Patrick von Platen --- docs/source/internal/generation_utils.rst | 24 ++ docs/source/main_classes/model.rst | 8 +- src/transformers/__init__.py | 16 + .../generation_flax_logits_process.py | 192 +++++++++ src/transformers/generation_flax_utils.py | 388 ++++++++++++++++++ src/transformers/modeling_flax_utils.py | 3 +- .../models/gpt2/modeling_flax_gpt2.py | 59 +-- src/transformers/utils/dummy_flax_objects.py | 30 ++ tests/test_generation_flax_logits_process.py | 163 ++++++++ tests/test_generation_flax_utils.py | 170 ++++++++ tests/test_modeling_flax_gpt2.py | 123 +++--- 11 files changed, 1080 insertions(+), 96 deletions(-) create mode 100644 src/transformers/generation_flax_logits_process.py create mode 100644 src/transformers/generation_flax_utils.py create mode 100644 tests/test_generation_flax_logits_process.py create mode 100644 tests/test_generation_flax_utils.py diff --git a/docs/source/internal/generation_utils.rst b/docs/source/internal/generation_utils.rst index 9051a44721..fe066e456d 100644 --- a/docs/source/internal/generation_utils.rst +++ b/docs/source/internal/generation_utils.rst @@ -78,6 +78,9 @@ GreedySearchOutput .. autoclass:: transformers.generation_utils.GreedySearchEncoderDecoderOutput :members: +.. 
autoclass:: transformers.generation_flax_utils.FlaxGreedySearchOutput + :members: + SampleOutput ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -88,6 +91,9 @@ SampleOutput .. autoclass:: transformers.generation_utils.SampleEncoderDecoderOutput :members: +.. autoclass:: transformers.generation_flax_utils.FlaxSampleOutput + :members: + BeamSearchOutput ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -160,6 +166,24 @@ generation. .. autoclass:: transformers.InfNanRemoveLogitsProcessor :members: __call__ +.. autoclass:: transformers.FlaxLogitsProcessor + :members: __call__ + +.. autoclass:: transformers.FlaxLogitsProcessorList + :members: __call__ + +.. autoclass:: transformers.FlaxLogitsWarper + :members: __call__ + +.. autoclass:: transformers.FlaxTemperatureLogitsWarper + :members: __call__ + +.. autoclass:: transformers.FlaxTopPLogitsWarper + :members: __call__ + +.. 
autoclass:: transformers.FlaxTopKLogitsWarper + :members: __call__ + StoppingCriteria ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/main_classes/model.rst b/docs/source/main_classes/model.rst index 0f93bec8ce..e311a36eaa 100644 --- a/docs/source/main_classes/model.rst +++ b/docs/source/main_classes/model.rst @@ -26,8 +26,9 @@ are common among all the models to: The other methods that are common to each model are defined in :class:`~transformers.modeling_utils.ModuleUtilsMixin` (for the PyTorch models) and :class:`~transformers.modeling_tf_utils.TFModuleUtilsMixin` (for the TensorFlow models) or -for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models) and -:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models) +for text generation, :class:`~transformers.generation_utils.GenerationMixin` (for the PyTorch models), +:class:`~transformers.generation_tf_utils.TFGenerationMixin` (for the TensorFlow models) and +:class:`~transformers.generation_flax_utils.FlaxGenerationMixin` (for the Flax/JAX models). PreTrainedModel @@ -74,6 +75,9 @@ Generation .. autoclass:: transformers.generation_tf_utils.TFGenerationMixin :members: +.. 
autoclass:: transformers.generation_flax_utils.FlaxGenerationMixin + :members: + Pushing to the Hub ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index da3d725006..26be362c5d 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1437,6 +1437,14 @@ else: # FLAX-backed objects if is_flax_available(): + _import_structure["generation_flax_logits_process"] = [ + "FlaxLogitsProcessor", + "FlaxLogitsProcessorList", + "FlaxLogitsWarper", + "FlaxTemperatureLogitsWarper", + "FlaxTopKLogitsWarper", + "FlaxTopPLogitsWarper", + ] _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] _import_structure["models.auto"].extend( [ @@ -2693,6 +2701,14 @@ if TYPE_CHECKING: from .utils.dummy_tf_objects import * if is_flax_available(): + from .generation_flax_logits_process import ( + FlaxLogitsProcessor, + FlaxLogitsProcessorList, + FlaxLogitsWarper, + FlaxTemperatureLogitsWarper, + FlaxTopKLogitsWarper, + FlaxTopPLogitsWarper, + ) from .modeling_flax_utils import FlaxPreTrainedModel from .models.auto import ( FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py new file mode 100644 index 0000000000..da4e77715c --- /dev/null +++ b/src/transformers/generation_flax_logits_process.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from abc import ABC + +import jax +import jax.lax as lax +import jax.numpy as jnp +import jaxlib.xla_extension as jax_xla + +from .file_utils import add_start_docstrings +from .utils.logging import get_logger + + +logger = get_logger(__name__) + + +LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See + :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam + search or log softmax for each vocabulary token when using beam search + kwargs: + Additional logits processor specific kwargs. + + Return: + :obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.vocab_size)`: The processed prediction scores. + +""" + + +class FlaxLogitsProcessor(ABC): + """Abstract base class for all logit processors that can be applied during generation.""" + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + """Flax method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 
+ ) + + +class FlaxLogitsWarper(ABC): + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + """Flax method for warping logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class FlaxLogitsProcessorList(list): + """ + This class can be used to create a list of :class:`~transformers.FlaxLogitsProcessor` or + :class:`~transformers.FlaxLogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits + from list and adds a specific `__call__` method to apply each :class:`~transformers.FlaxLogitsProcessor` or + :class:`~transformers.FlaxLogitsWarper` to the inputs. + """ + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray: + for processor in self: + function_args = inspect.signature(processor.__call__).parameters + if len(function_args) > 2: + assert all( + arg in kwargs for arg in list(function_args.keys())[2:] + ), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor." + scores = processor(input_ids, scores, **kwargs) + else: + scores = processor(input_ids, scores) + return scores + + +class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): + r""" + :class:`transformers.LogitsWarper` for temperature (exponential scaling output probability distribution). + + Args: + temperature (:obj:`float`): + The value used to module the logits distribution. 
+ """ + + def __init__(self, temperature: float): + if not isinstance(temperature, float) or not (temperature > 0): + raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + + self.temperature = temperature + + def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + scores = scores / self.temperature + return scores + + +class FlaxTopPLogitsWarper(FlaxLogitsWarper): + """ + :class:`transformers.LogitsWarper` that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= + prob_cut_off. + + Args: + top_p (:obj:`float`): + If set to < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are + kept for generation. + filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0): + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1]) + + mask_scores = jnp.full_like(scores, self.filter_value) + cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1) + score_mask = cumulative_probs < self.top_p + + # include the token that is higher than top_p as well + score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True) + + # min tokens to keep + score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True) + + topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores) + next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1] + + return next_scores + + +class FlaxTopKLogitsWarper(FlaxLogitsWarper): + r""" + :class:`transformers.LogitsWarper` that performs top-k, i.e. restricting to the k highest probability elements. + + Args: + top_k (:obj:`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (:obj:`float`, `optional`, defaults to :obj:`-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (:obj:`int`, `optional`, defaults to 1): + Minimum number of tokens that cannot be filtered. 
+ """ + + def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = top_k + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray: + batch_size, vocab_size = scores.shape + next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value) + + topk = min(max(self.top_k, self.min_tokens_to_keep), scores.shape[-1]) # Safety check + topk_scores, topk_indices = lax.top_k(scores, topk) + shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten() + topk_scores_flat = topk_scores.flatten() + topk_indices_flat = topk_indices.flatten() + shift + + next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat) + next_scores = next_scores_flat.reshape(batch_size, vocab_size) + return next_scores diff --git a/src/transformers/generation_flax_utils.py b/src/transformers/generation_flax_utils.py new file mode 100644 index 0000000000..d12f8c6d49 --- /dev/null +++ b/src/transformers/generation_flax_utils.py @@ -0,0 +1,388 @@ +# coding=utf-8 +# Copyright 2021 The Google AI Flax Team Authors, and The HuggingFace Inc. team. +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Dict, Optional + +import flax +import jax +import jax.numpy as jnp +import jaxlib.xla_extension as jax_xla +from jax import lax + +from .file_utils import ModelOutput +from .generation_flax_logits_process import ( + FlaxLogitsProcessorList, + FlaxTemperatureLogitsWarper, + FlaxTopKLogitsWarper, + FlaxTopPLogitsWarper, +) +from .utils import logging + + +logger = logging.get_logger(__name__) + + +@flax.struct.dataclass +class FlaxGreedySearchOutput(ModelOutput): + """ + Flax Base class for outputs of decoder-only generation models using greedy search. + + + Args: + sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is + padded to :obj:`max_length`. + """ + + sequences: jax_xla.DeviceArray = None + + +@flax.struct.dataclass +class FlaxSampleOutput(ModelOutput): + """ + Flax Base class for outputs of decoder-only generation models using sampling. + + + Args: + sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`): + The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is + padded to :obj:`max_length`.
+ """ + + sequences: jax_xla.DeviceArray = None + + +@flax.struct.dataclass +class GreedyState: + cur_len: jax_xla.DeviceArray + sequences: jax_xla.DeviceArray + current_token: jax_xla.DeviceArray + is_sent_finished: jax_xla.DeviceArray + model_kwargs: Dict[str, jax_xla.DeviceArray] + + +@flax.struct.dataclass +class SampleState: + cur_len: jax_xla.DeviceArray + sequences: jax_xla.DeviceArray + current_token: jax_xla.DeviceArray + is_sent_finished: jax_xla.DeviceArray + prng_key: jax_xla.DeviceArray + model_kwargs: Dict[str, jax_xla.DeviceArray] + + +class FlaxGenerationMixin: + """ + A class containing all of the functions supporting generation, to be used as a mixin in + :class:`~transformers.FlaxPreTrainedModel`. + """ + + @staticmethod + def _run_loop_in_debug(cond_fn, body_fn, init_state): + """ + Run generation in untraced mode. This should only be used for debugging purposes. + """ + state = init_state + while cond_fn(state): + state = body_fn(state) + return state + + def generate( + self, + input_ids: jax_xla.DeviceArray, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + do_sample: Optional[bool] = None, + prng_key: Optional[jax_xla.DeviceArray] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + trace: bool = True, + **model_kwargs, + ): + r""" + Generates sequences for models with a language modeling head. The method currently supports greedy decoding + and multinomial sampling. + + Apart from :obj:`input_ids`, all the arguments below will default to the value of the attribute of the same + name inside the :class:`~transformers.PretrainedConfig` of the model. The default values indicated are the + default values of that config. + + Most of these parameters are explained in more detail in `this blog post + <https://huggingface.co/blog/how-to-generate>`__.
+ + Parameters: + + input_ids (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + max_length (:obj:`int`, `optional`, defaults to 20): + The maximum length of the sequence to be generated. + do_sample (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to use sampling; use greedy decoding otherwise. + temperature (:obj:`float`, `optional`, defaults to 1.0): + The value used to modulate the next token probabilities. + top_k (:obj:`int`, `optional`, defaults to 50): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (:obj:`float`, `optional`, defaults to 1.0): + If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or + higher are kept for generation. + pad_token_id (:obj:`int`, `optional`): + The id of the `padding` token. + eos_token_id (:obj:`int`, `optional`): + The id of the `end-of-sequence` token. + prng_key (:obj:`jax_xla.DeviceArray`, `optional`): + The PRNG key used for sampling. Falls back to :obj:`jax.random.PRNGKey(0)` when not provided. + trace (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to + a considerably slower runtime. + model_kwargs: + Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model. + + Return: + :class:`~transformers.file_utils.ModelOutput`.
+ + Examples:: + >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + >>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2") + >>> input_context = "The dog" + >>> # encode input context + >>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids + >>> # generate candidates using sampling + >>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True) + >>> print("Generated:", tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)) + """ + # set init values + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) + + do_sample = do_sample if do_sample is not None else self.config.do_sample + + if do_sample: + logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) + return self._sample( + input_ids, + max_length, + pad_token_id, + eos_token_id, + prng_key, + logits_warper=logits_warper, + model_kwargs=model_kwargs, + trace=trace, + ) + else: + return self._greedy_search( + input_ids, max_length, pad_token_id, eos_token_id, trace=trace, model_kwargs=model_kwargs + ) + + def _get_logits_warper( + self, top_k: int = None, top_p: float = None, temperature: float = None + ) -> FlaxLogitsProcessorList: + """ + This method returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant + :obj:`~transformers.FlaxLogitsWarper` instances used for multinomial sampling.
+ """ + + # init warp parameters + top_k = top_k if top_k is not None else self.config.top_k + top_p = top_p if top_p is not None else self.config.top_p + temperature = temperature if temperature is not None else self.config.temperature + # instantiate warpers list + warpers = FlaxLogitsProcessorList() + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if temperature is not None and temperature != 1.0: + warpers.append(FlaxTemperatureLogitsWarper(temperature)) + if top_k is not None and top_k != 0: + warpers.append(FlaxTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1)) + if top_p is not None and top_p < 1.0: + warpers.append(FlaxTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1)) + + return warpers + + def _greedy_search( + self, + input_ids: None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + trace: bool = True, + model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, + ): + # init values + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + + batch_size, cur_len = input_ids.shape + + eos_token_id = jnp.array(eos_token_id) + pad_token_id = jnp.array(pad_token_id) + cur_len = jnp.array(cur_len) + + # per batch-item holding current token in loop. + sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32) + sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0)) + + # per batch-item state bit indicating if sentence has finished. 
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_) + + model = self + + # initialize model specific kwargs + model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs) + + # initialize state + state = GreedyState( + cur_len=cur_len, + sequences=sequences, + current_token=input_ids, + is_sent_finished=is_sent_finished, + model_kwargs=model_kwargs, + ) + + def greedy_search_cond_fn(state): + """state termination condition fn.""" + has_reached_max_length = state.cur_len == max_length + all_sequence_finished = jnp.all(state.is_sent_finished) + finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished) + return ~finish_generation + + def greedy_search_body_fn(state): + """state update fn.""" + model_outputs = model(state.current_token, **state.model_kwargs) + next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1) + + next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) + next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished + next_token = next_token[:, None] + + next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) + next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs) + + return GreedyState( + cur_len=state.cur_len + 1, + sequences=next_sequences, + current_token=next_token, + is_sent_finished=next_is_sent_finished, + model_kwargs=next_model_kwargs, + ) + + # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU + state = greedy_search_body_fn(state) + + if not trace: + state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state) + else: + state = lax.while_loop(greedy_search_cond_fn, greedy_search_body_fn, state) + + return FlaxGreedySearchOutput(sequences=state.sequences) + + def _sample( + self, + input_ids: None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + 
eos_token_id: Optional[int] = None, + prng_key: Optional[jax_xla.DeviceArray] = None, + model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None, + logits_warper: Optional[FlaxLogitsProcessorList] = None, + trace: bool = True, + ): + # init values + max_length = max_length if max_length is not None else self.config.max_length + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) + + batch_size, cur_len = input_ids.shape + + eos_token_id = jnp.array(eos_token_id) + pad_token_id = jnp.array(pad_token_id) + cur_len = jnp.array(cur_len) + + # per batch-item holding current token in loop. + sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32) + sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0)) + + # per batch-item state bit indicating if sentence has finished. 
+ is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_) + + model = self + + # initialize model specific kwargs + model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs) + + # initialize state + state = SampleState( + cur_len=cur_len, + sequences=sequences, + current_token=input_ids, + is_sent_finished=is_sent_finished, + prng_key=prng_key, + model_kwargs=model_kwargs, + ) + + def sample_search_cond_fn(state): + """state termination condition fn.""" + has_reached_max_length = state.cur_len == max_length + all_sequence_finished = jnp.all(state.is_sent_finished) + finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished) + return ~finish_generation + + def sample_search_body_fn(state): + """state update fn.""" + prng_key, prng_key_next = jax.random.split(state.prng_key) + model_outputs = model(state.current_token, **state.model_kwargs) + + logits = model_outputs.logits[:, -1] + + # apply top_k, top_k, temperature + logits = logits_warper(state.sequences, logits) + + next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1) + + next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) + next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished + next_token = next_token[:, None] + + next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len)) + next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs) + + return SampleState( + cur_len=state.cur_len + 1, + sequences=next_sequences, + current_token=next_token, + is_sent_finished=next_is_sent_finished, + model_kwargs=next_model_kwargs, + prng_key=prng_key_next, + ) + + # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU + state = sample_search_body_fn(state) + + if not trace: + state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state) + else: + 
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state) + + return FlaxSampleOutput(sequences=state.sequences) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 3e33f66b27..0fc0298d6c 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -41,6 +41,7 @@ from .file_utils import ( is_remote_url, replace_return_docstrings, ) +from .generation_flax_utils import FlaxGenerationMixin from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict from .utils import logging @@ -57,7 +58,7 @@ ACT2FN = { } -class FlaxPreTrainedModel(PushToHubMixin): +class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): r""" Base class for all models. diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index 3d813791ee..19bac78c8a 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -20,7 +20,6 @@ import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, unfreeze from flax.linen import combine_masks, dot_product_attention, make_causal_mask -from flax.traverse_util import flatten_dict from jax import lax from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -322,13 +321,6 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype) - @property - def _attn_layer_name(self): - attn_layer_key_tuple = ("h", "0", "attn") - if self.base_model_prefix in set(self.params.keys()): - attn_layer_key_tuple = (self.base_model_prefix,) + attn_layer_key_tuple - return attn_layer_key_tuple - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, 
dtype="i4") @@ -381,28 +373,13 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel): batch_size, sequence_length = input_ids.shape if position_ids is None: - if past_key_values is not None and input_ids.shape[-1] == 1: - # if `past_key_values` are passed and input_ids are longer than 1, we are in cached auto-regressive generation. It has to be made sure that position_ids are set correctly - cache_shift = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cache_index",)] - position_ids = jnp.broadcast_to( - jnp.arange(self.config.max_position_embeddings)[None, :], - (batch_size, self.config.max_position_embeddings), - ) - position_ids = lax.dynamic_slice(position_ids, (0, cache_shift), (batch_size, 1)) - else: - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + if past_key_values is not None: + raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") + + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) if attention_mask is None: - # if past_key_values are passed we need to create an attention_mask of the same length as `cache_length` - if past_key_values is not None: - cache_length = flatten_dict(unfreeze(past_key_values))[self._attn_layer_name + ("cached_key",)].shape[ - 1 - ] - else: - cache_length = sequence_length - - # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. But since GPT2 uses a causal mask, those positions are masked anyways. 
Thus we can create a single static attention_mask here, which is more efficient for compilation - attention_mask = jnp.ones((batch_size, cache_length)) + attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} @@ -627,6 +604,32 @@ class FlaxGPT2LMHeadModule(nn.Module): class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel): module_class = FlaxGPT2LMHeadModule + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since GPT2 uses a causal mask, those positions are masked anyways. + # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + append_call_sample_docstring( FlaxGPT2LMHeadModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index acd9778436..0d35d3b695 100644 --- 
a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -2,6 +2,36 @@ from ..file_utils import requires_backends +class FlaxLogitsProcessor: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxLogitsProcessorList: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxLogitsWarper: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxTemperatureLogitsWarper: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxTopKLogitsWarper: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxTopPLogitsWarper: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxPreTrainedModel: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) diff --git a/tests/test_generation_flax_logits_process.py b/tests/test_generation_flax_logits_process.py new file mode 100644 index 0000000000..4dacb5dc0a --- /dev/null +++ b/tests/test_generation_flax_logits_process.py @@ -0,0 +1,163 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +import numpy as np + +from transformers import is_flax_available +from transformers.testing_utils import require_flax + +from .test_modeling_flax_common import ids_tensor + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from transformers.generation_flax_logits_process import ( + FlaxLogitsProcessorList, + FlaxTemperatureLogitsWarper, + FlaxTopKLogitsWarper, + FlaxTopPLogitsWarper, + ) + + +@require_flax +class LogitsProcessorTest(unittest.TestCase): + def _get_uniform_logits(self, batch_size: int, length: int): + scores = np.ones((batch_size, length)) / length + return scores + + def test_temperature_dist_warper(self): + input_ids = None + length = 20 + + scores = self._get_uniform_logits(batch_size=2, length=length) + + # tweak scores to not be uniform anymore + scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch + scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch + + # compute softmax + probs = jax.nn.softmax(scores, axis=-1) + + temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5) + temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3) + + warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1) + warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1) + + # uniform distribution stays uniform + self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)) + self.assertTrue(jnp.allclose(probs[0, :], warped_prob_smooth[0, :], atol=1e-3)) + + # sharp peaks get higher, valleys get lower + self.assertLess(probs[1, :].max(), warped_prob_sharp[1, :].max()) + self.assertGreater(probs[1, :].min(), warped_prob_sharp[1, :].min()) + + # smooth peaks get lower, valleys get higher + self.assertGreater(probs[1, :].max(), warped_prob_smooth[1, :].max()) + self.assertLess(probs[1, :].min(), warped_prob_smooth[1, :].min()) + + def test_top_k_dist_warper(self): + input_ids = None + vocab_size = 10 
+ batch_size = 2 + + # create ramp distribution + ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() + ramp_logits[1:, : vocab_size // 2] = ramp_logits[1:, : vocab_size // 2] + vocab_size + + top_k_warp = FlaxTopKLogitsWarper(3) + + scores = top_k_warp(input_ids, ramp_logits) + + # check that correct tokens are filtered + self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False]) + self.assertListEqual(jnp.isinf(scores[1]).tolist(), 2 * [True] + 3 * [False] + 5 * [True]) + + # check special case + length = 5 + top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3) + + ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy() + scores = top_k_warp_safety_check(input_ids, ramp_logits) + + # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified + self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2]) + + def test_top_p_dist_warper(self): + input_ids = None + vocab_size = 10 + batch_size = 2 + + # create distribution and take log (inverse to Softmax as taken in TopPLogitsWarper) + dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]])) + + top_p_warp = FlaxTopPLogitsWarper(0.7) + filtered_dist = np.exp(top_p_warp(input_ids, dist)) + + # dist should be filtered to keep min num values so that sum is >= 0.7 + # exp (-inf) => 0 + EXPECTED_FILTERED_DIST = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]]) + self.assertTrue(np.allclose(filtered_dist, EXPECTED_FILTERED_DIST, atol=1e-3)) + + # check edge cases with negative and extreme logits + ramp_logits = np.broadcast_to(np.arange(vocab_size)[None, :], (batch_size, vocab_size)).copy() - ( + vocab_size // 2 + ) + + # make ramp_logits more extreme + ramp_logits[1] = ramp_logits[1] * 100.0 + + # make sure at least 2 tokens are kept + top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0) + filtered_dist = 
top_p_warp(input_ids, ramp_logits) + + # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. + self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2]) + + def test_processor_list(self): + batch_size = 4 + sequence_length = 10 + vocab_size = 15 + + # dummy input_ids and scores + input_ids = ids_tensor((batch_size, sequence_length), vocab_size) + input_ids_comp = input_ids.copy() + + scores = self._get_uniform_logits(batch_size, vocab_size) + scores_comp = scores.copy() + + # instantiate all dist processors + temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5) + top_k_warp = FlaxTopKLogitsWarper(3) + top_p_warp = FlaxTopPLogitsWarper(0.8) + + # no processor list + scores = temp_dist_warp(input_ids, scores) + scores = top_k_warp(input_ids, scores) + scores = top_p_warp(input_ids, scores) + + # with processor list + processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp]) + scores_comp = processor(input_ids, scores_comp) + + # scores should be equal + self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3)) + + # input_ids should never be changed + self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist()) diff --git a/tests/test_generation_flax_utils.py b/tests/test_generation_flax_utils.py new file mode 100644 index 0000000000..9b3e529c18 --- /dev/null +++ b/tests/test_generation_flax_utils.py @@ -0,0 +1,170 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np + +from transformers import is_flax_available +from transformers.testing_utils import require_flax + + +if is_flax_available(): + import os + + import jax + import jax.numpy as jnp + from jax import jit + + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 + + +def ids_tensor(shape, vocab_size, rng=None): + """Creates a random int32 tensor of the shape within the vocab size.""" + if rng is None: + rng = random.Random() + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.randint(0, vocab_size - 1)) + + output = np.array(values, dtype=jnp.int32).reshape(shape) + + return output + + +def random_attention_mask(shape, rng=None): + attn_mask = ids_tensor(shape, vocab_size=2, rng=rng) + # make sure that at least one token is attended to for each batch + attn_mask[:, -1] = 1 + return attn_mask + + +@require_flax +class FlaxGenerationTesterMixin: + model_tester = None + all_generative_model_classes = () + + def _get_input_ids_and_config(self): + config, inputs = self.model_tester.prepare_config_and_inputs_for_common() + + # cut to half length & take max batch_size 2 + max_batch_size = 2 + sequence_length = inputs["input_ids"].shape[-1] // 2 + input_ids = inputs["input_ids"][:max_batch_size, :sequence_length] + + attention_mask = jnp.ones_like(input_ids) + attention_mask = attention_mask[:max_batch_size, :sequence_length] + + # generate max 5 tokens + max_length = input_ids.shape[-1] + 5 + if config.eos_token_id is not None and config.pad_token_id is None: + # hack to allow generate for models such as GPT2 as is done in `generate()` + config.pad_token_id = config.eos_token_id + return config, input_ids, attention_mask, max_length + + def test_greedy_generate(self): + config, input_ids, _, max_length = self._get_input_ids_and_config() 
+ config.do_sample = False + config.max_length = max_length + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_sample_generate(self): + config, input_ids, _, max_length = self._get_input_ids_and_config() + config.do_sample = True + config.max_length = max_length + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_sample_generate_logits_warper(self): + config, input_ids, _, max_length = self._get_input_ids_and_config() + config.do_sample = True + config.max_length = max_length + config.temperature = 0.8 + config.top_k = 10 + config.top_p = 0.3 + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_greedy_generate_attn_mask(self): + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # pad attention mask on the left + attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0) + + config.do_sample = False + config.max_length = max_length + + for model_class in 
self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) + + def test_sample_generate_attn_mask(self): + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + + # pad attention mask on the left + attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0) + + config.do_sample = True + config.max_length = max_length + + for model_class in self.all_generative_model_classes: + model = model_class(config) + + generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences + self.assertEqual(generation_outputs.shape[-1], max_length) + + jit_generate = jit(model.generate) + jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences + + self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist()) diff --git a/tests/test_modeling_flax_gpt2.py b/tests/test_modeling_flax_gpt2.py index f6abc74e42..c79fc5ef35 100644 --- a/tests/test_modeling_flax_gpt2.py +++ b/tests/test_modeling_flax_gpt2.py @@ -19,16 +19,16 @@ import unittest import numpy as np import transformers -from transformers import GPT2Config, is_flax_available, is_torch_available +from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_available from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from .test_generation_flax_utils import FlaxGenerationTesterMixin from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask if is_flax_available(): import jax import jax.numpy as jnp - from jax import lax from transformers.modeling_flax_pytorch_utils 
import ( convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model, @@ -116,8 +116,25 @@ class FlaxGPT2ModelTester: model = model_class_name(config) past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) - outputs_cache = model(input_ids[:, :-1], past_key_values=past_key_values) - outputs_cache_next = model(input_ids[:, -1:], past_key_values=outputs_cache.past_key_values) + attention_mask = jnp.ones((input_ids.shape[0], max_decoder_length), dtype="i4") + + position_ids = jnp.broadcast_to( + jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) + ) + outputs_cache = model( + input_ids[:, :-1], + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") + outputs_cache_next = model( + input_ids[:, -1:], + attention_mask=attention_mask, + past_key_values=outputs_cache.past_key_values, + position_ids=position_ids, + ) outputs = model(input_ids) @@ -134,10 +151,22 @@ class FlaxGPT2ModelTester: ) past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) + position_ids = jnp.broadcast_to( + jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) + ) - outputs_cache = model(input_ids[:, :-1], attention_mask=attention_mask_cache, past_key_values=past_key_values) + outputs_cache = model( + input_ids[:, :-1], + attention_mask=attention_mask_cache, + past_key_values=past_key_values, + position_ids=position_ids, + ) + position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") outputs_cache_next = model( - input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, attention_mask=attention_mask_cache + input_ids[:, -1:], + past_key_values=outputs_cache.past_key_values, + attention_mask=attention_mask_cache, + position_ids=position_ids, ) outputs = model(input_ids, attention_mask=attention_mask) 
@@ -145,66 +174,12 @@ class FlaxGPT2ModelTester: diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))) self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}") - def check_use_cache_generation(self, config, input_ids): - prompt_length = 3 - model = FlaxGPT2LMHeadModel(config) - max_length = 10 - batch_size = 1 - - prompt_ids = input_ids[:1, :prompt_length] - - # put all generation logic into one function - def generate(prompt_ids): - def first_pass(prompt_ids): - logits, cache = model(prompt_ids, past_key_values=past_key_values)[:2] - next_token = jnp.argmax(logits[:, -1:], axis=-1) - return next_token, cache - - def greedy_search_cond_fn(state): - cur_len, _, _, _ = state - return ~(cur_len == max_length - 1) - - def greedy_search_body_fn(state): - cur_len, sequences, current_token, cache = state - next_sequences = lax.dynamic_update_slice(sequences, current_token, (0, cur_len)) - - next_logits, next_cache = model(current_token, past_key_values=cache)[:2] - next_token = jnp.argmax(next_logits, axis=-1) - - return cur_len + 1, next_sequences, next_token, next_cache - - # init tensor to be filled with generation result - init_sequences = jnp.zeros((batch_size, max_length), dtype="i4") - init_sequences = lax.dynamic_update_slice(init_sequences, prompt_ids, (0, 0)) - - # init past key values for cache - past_key_values = model.init_cache(batch_size, max_length) - - # first pass with long prompt - next_token, cache = first_pass(prompt_ids) - - # prepare state for generation loop - init_state = (jnp.array(prompt_length), init_sequences, next_token, cache) - - # fast generation - _, output_sequences, final_token, _ = lax.while_loop( - greedy_search_cond_fn, greedy_search_body_fn, init_state - ) - - # append last token - output_sequences = lax.dynamic_update_slice(output_sequences, final_token, (0, max_length - 1)) - - return output_sequences - - jit_generate = jax.jit(generate) - output_sequences = jit_generate(prompt_ids) - 
self.parent.assertEqual(output_sequences.shape, (1, max_length)) - @require_flax -class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase): +class FlaxGPT2ModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase): all_model_classes = (FlaxGPT2Model, FlaxGPT2LMHeadModel) if is_flax_available() else () + all_generative_model_classes = (FlaxGPT2LMHeadModel,) if is_flax_available() else () def setUp(self): self.model_tester = FlaxGPT2ModelTester(self) @@ -221,9 +196,27 @@ class FlaxGPT2ModelTest(FlaxModelTesterMixin, unittest.TestCase): model_class_name, config, input_ids, attention_mask ) - def test_use_cache_generation(self): - config, input_ids, _ = self.model_tester.prepare_config_and_inputs() - self.model_tester.check_use_cache_generation(config, input_ids) + @slow + def test_batch_generation(self): + tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="", padding_side="left") + inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True) + + model = FlaxGPT2LMHeadModel.from_pretrained("gpt2") + model.do_sample = False + model.config.pad_token_id = model.config.eos_token_id + + jit_generate = jax.jit(model.generate) + + output_sequences = jit_generate(inputs["input_ids"], attention_mask=inputs["attention_mask"]).sequences + + output_string = tokenizer.batch_decode(output_sequences, skip_special_tokens=True) + + expected_string = [ + "Hello this is a long string of words. I'm going to try to explain what I mean.", + "Hey, I'm not sure if I'm going to be able to do", + ] + + self.assertListEqual(output_string, expected_string) # overwrite from common since `attention_mask` in combination # with `causal_mask` behaves slighly differently