[Flax] Add Beam Search (#12131)
* fix_torch_device_generate_test * remove @ * push new logit processors * add processors * save first working version * save intermediate * finish * make style * make fix-copies * finish * Update tests/test_modeling_flax_bart.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick@huggingface.co> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
committed by
GitHub
parent
802ffaff0d
commit
c3c39f7e84
@@ -186,6 +186,15 @@ generation.
|
|||||||
.. autoclass:: transformers.FlaxTopKLogitsWarper
|
.. autoclass:: transformers.FlaxTopKLogitsWarper
|
||||||
:members: __call__
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxForcedBOSTokenLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxForcedEOSTokenLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
.. autoclass:: transformers.FlaxMinLengthLogitsProcessor
|
||||||
|
:members: __call__
|
||||||
|
|
||||||
|
|
||||||
StoppingCriteria
|
StoppingCriteria
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -1486,9 +1486,12 @@ else:
|
|||||||
# FLAX-backed objects
|
# FLAX-backed objects
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["generation_flax_logits_process"] = [
|
_import_structure["generation_flax_logits_process"] = [
|
||||||
|
"FlaxForcedBOSTokenLogitsProcessor",
|
||||||
|
"FlaxForcedEOSTokenLogitsProcessor",
|
||||||
"FlaxLogitsProcessor",
|
"FlaxLogitsProcessor",
|
||||||
"FlaxLogitsProcessorList",
|
"FlaxLogitsProcessorList",
|
||||||
"FlaxLogitsWarper",
|
"FlaxLogitsWarper",
|
||||||
|
"FlaxMinLengthLogitsProcessor",
|
||||||
"FlaxTemperatureLogitsWarper",
|
"FlaxTemperatureLogitsWarper",
|
||||||
"FlaxTopKLogitsWarper",
|
"FlaxTopKLogitsWarper",
|
||||||
"FlaxTopPLogitsWarper",
|
"FlaxTopPLogitsWarper",
|
||||||
@@ -2814,9 +2817,12 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .generation_flax_logits_process import (
|
from .generation_flax_logits_process import (
|
||||||
|
FlaxForcedBOSTokenLogitsProcessor,
|
||||||
|
FlaxForcedEOSTokenLogitsProcessor,
|
||||||
FlaxLogitsProcessor,
|
FlaxLogitsProcessor,
|
||||||
FlaxLogitsProcessorList,
|
FlaxLogitsProcessorList,
|
||||||
FlaxLogitsWarper,
|
FlaxLogitsWarper,
|
||||||
|
FlaxMinLengthLogitsProcessor,
|
||||||
FlaxTemperatureLogitsWarper,
|
FlaxTemperatureLogitsWarper,
|
||||||
FlaxTopKLogitsWarper,
|
FlaxTopKLogitsWarper,
|
||||||
FlaxTopPLogitsWarper,
|
FlaxTopPLogitsWarper,
|
||||||
|
|||||||
@@ -81,16 +81,18 @@ class FlaxLogitsProcessorList(list):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, **kwargs) -> jax_xla.DeviceArray:
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int, **kwargs
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
for processor in self:
|
for processor in self:
|
||||||
function_args = inspect.signature(processor.__call__).parameters
|
function_args = inspect.signature(processor.__call__).parameters
|
||||||
if len(function_args) > 2:
|
if len(function_args) > 3:
|
||||||
assert all(
|
assert all(
|
||||||
arg in kwargs for arg in list(function_args.keys())[2:]
|
arg in kwargs for arg in list(function_args.keys())[2:]
|
||||||
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
|
), f"Make sure that all the required parameters: {list(function_args.keys())} for {processor.__class__} are passed to the logits processor."
|
||||||
scores = processor(input_ids, scores, **kwargs)
|
scores = processor(input_ids, scores, cur_len, **kwargs)
|
||||||
else:
|
else:
|
||||||
scores = processor(input_ids, scores)
|
scores = processor(input_ids, scores, cur_len)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -109,7 +111,9 @@ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
|
|||||||
|
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
scores = scores / self.temperature
|
scores = scores / self.temperature
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -137,7 +141,9 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
|||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])
|
||||||
|
|
||||||
mask_scores = jnp.full_like(scores, self.filter_value)
|
mask_scores = jnp.full_like(scores, self.filter_value)
|
||||||
@@ -177,7 +183,9 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
|||||||
self.filter_value = filter_value
|
self.filter_value = filter_value
|
||||||
self.min_tokens_to_keep = min_tokens_to_keep
|
self.min_tokens_to_keep = min_tokens_to_keep
|
||||||
|
|
||||||
def __call__(self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray) -> jax_xla.DeviceArray:
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
batch_size, vocab_size = scores.shape
|
batch_size, vocab_size = scores.shape
|
||||||
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)
|
||||||
|
|
||||||
@@ -190,3 +198,94 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
|||||||
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
|
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
|
||||||
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
||||||
return next_scores
|
return next_scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the first generated token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bos_token_id (:obj:`int`):
|
||||||
|
The id of the token to force as the first generated token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bos_token_id: int):
|
||||||
|
self.bos_token_id = bos_token_id
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
|
new_scores = jnp.full(scores.shape, -float("inf"))
|
||||||
|
|
||||||
|
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
||||||
|
|
||||||
|
scores = jnp.where(
|
||||||
|
apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores
|
||||||
|
)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.FlaxLogitsProcessor` that enforces the specified token as the last generated token when
|
||||||
|
:obj:`max_length` is reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_length (:obj:`int`):
|
||||||
|
The maximum length of the sequence to be generated.
|
||||||
|
eos_token_id (:obj:`int`):
|
||||||
|
The id of the token to force as the last generated token when :obj:`max_length` is reached.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_length: int, eos_token_id: int):
|
||||||
|
self.max_length = max_length
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
|
new_scores = jnp.full(scores.shape, -float("inf"))
|
||||||
|
|
||||||
|
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
||||||
|
|
||||||
|
scores = jnp.where(
|
||||||
|
apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores
|
||||||
|
)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
|
||||||
|
r"""
|
||||||
|
:class:`~transformers.FlaxLogitsProcessor` enforcing a minimum length by setting the EOS token's probability to 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_length (:obj:`int`):
|
||||||
|
The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`.
|
||||||
|
eos_token_id (:obj:`int`):
|
||||||
|
The id of the `end-of-sequence` token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, min_length: int, eos_token_id: int):
|
||||||
|
if not isinstance(min_length, int) or min_length < 0:
|
||||||
|
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
|
||||||
|
|
||||||
|
if not isinstance(eos_token_id, int) or eos_token_id < 0:
|
||||||
|
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
|
||||||
|
|
||||||
|
self.min_length = min_length
|
||||||
|
self.eos_token_id = eos_token_id
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, input_ids: jax_xla.DeviceArray, scores: jax_xla.DeviceArray, cur_len: int
|
||||||
|
) -> jax_xla.DeviceArray:
|
||||||
|
|
||||||
|
# create boolean flag to decide if min length penalty should be applied
|
||||||
|
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
||||||
|
|
||||||
|
scores = jnp.where(
|
||||||
|
apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
|
||||||
|
)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|||||||
@@ -17,6 +17,8 @@
|
|||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import flax
|
import flax
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -25,7 +27,10 @@ from jax import lax
|
|||||||
|
|
||||||
from .file_utils import ModelOutput
|
from .file_utils import ModelOutput
|
||||||
from .generation_flax_logits_process import (
|
from .generation_flax_logits_process import (
|
||||||
|
FlaxForcedBOSTokenLogitsProcessor,
|
||||||
|
FlaxForcedEOSTokenLogitsProcessor,
|
||||||
FlaxLogitsProcessorList,
|
FlaxLogitsProcessorList,
|
||||||
|
FlaxMinLengthLogitsProcessor,
|
||||||
FlaxTemperatureLogitsWarper,
|
FlaxTemperatureLogitsWarper,
|
||||||
FlaxTopKLogitsWarper,
|
FlaxTopKLogitsWarper,
|
||||||
FlaxTopPLogitsWarper,
|
FlaxTopPLogitsWarper,
|
||||||
@@ -43,9 +48,8 @@ class FlaxGreedySearchOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
|
||||||
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
|
The generated sequences.
|
||||||
padded to :obj:`max_length`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: jax_xla.DeviceArray = None
|
sequences: jax_xla.DeviceArray = None
|
||||||
@@ -58,19 +62,35 @@ class FlaxSampleOutput(ModelOutput):
|
|||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequences (:obj:`torch.LongTensor` of shape :obj:`(batch_size, max_length)`):
|
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
|
||||||
The generated sequences. If all batches finished early due to the :obj:`eos_token_id`, :obj:`sequences` is
|
The generated sequences.
|
||||||
padded to :obj:`max_length`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sequences: jax_xla.DeviceArray = None
|
sequences: jax_xla.DeviceArray = None
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class FlaxBeamSearchOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Flax Base class for outputs of decoder-only generation models using beam search.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sequences (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, max_length)`):
|
||||||
|
The generated sequences.
|
||||||
|
scores (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size,)`):
|
||||||
|
The scores (log probabilities) of the generated sequences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sequences: jax_xla.DeviceArray = None
|
||||||
|
scores: jax_xla.DeviceArray = None
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
class GreedyState:
|
class GreedyState:
|
||||||
cur_len: jax_xla.DeviceArray
|
cur_len: jax_xla.DeviceArray
|
||||||
sequences: jax_xla.DeviceArray
|
sequences: jax_xla.DeviceArray
|
||||||
current_token: jax_xla.DeviceArray
|
running_token: jax_xla.DeviceArray
|
||||||
is_sent_finished: jax_xla.DeviceArray
|
is_sent_finished: jax_xla.DeviceArray
|
||||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||||
|
|
||||||
@@ -79,12 +99,23 @@ class GreedyState:
|
|||||||
class SampleState:
|
class SampleState:
|
||||||
cur_len: jax_xla.DeviceArray
|
cur_len: jax_xla.DeviceArray
|
||||||
sequences: jax_xla.DeviceArray
|
sequences: jax_xla.DeviceArray
|
||||||
current_token: jax_xla.DeviceArray
|
running_token: jax_xla.DeviceArray
|
||||||
is_sent_finished: jax_xla.DeviceArray
|
is_sent_finished: jax_xla.DeviceArray
|
||||||
prng_key: jax_xla.DeviceArray
|
prng_key: jax_xla.DeviceArray
|
||||||
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||||
|
|
||||||
|
|
||||||
|
@flax.struct.dataclass
|
||||||
|
class BeamSearchState:
|
||||||
|
cur_len: jax_xla.DeviceArray
|
||||||
|
running_sequences: jax_xla.DeviceArray
|
||||||
|
running_scores: jax_xla.DeviceArray
|
||||||
|
sequences: jax_xla.DeviceArray
|
||||||
|
scores: jax_xla.DeviceArray
|
||||||
|
is_sent_finished: jax_xla.DeviceArray
|
||||||
|
model_kwargs: Dict[str, jax_xla.DeviceArray]
|
||||||
|
|
||||||
|
|
||||||
class FlaxGenerationMixin:
|
class FlaxGenerationMixin:
|
||||||
"""
|
"""
|
||||||
A class containing all of the functions supporting generation, to be used as a mixin in
|
A class containing all of the functions supporting generation, to be used as a mixin in
|
||||||
@@ -110,6 +141,10 @@ class FlaxGenerationMixin:
|
|||||||
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
|
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expand_to_num_beams(tensor, num_beams):
|
||||||
|
return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids: jax_xla.DeviceArray,
|
input_ids: jax_xla.DeviceArray,
|
||||||
@@ -123,6 +158,13 @@ class FlaxGenerationMixin:
|
|||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
|
num_beams: Optional[int] = None,
|
||||||
|
no_repeat_ngram_size: Optional[int] = None,
|
||||||
|
min_length: Optional[int] = None,
|
||||||
|
forced_bos_token_id: Optional[int] = None,
|
||||||
|
forced_eos_token_id: Optional[int] = None,
|
||||||
|
length_penalty: Optional[float] = None,
|
||||||
|
early_stopping: Optional[bool] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
@@ -159,6 +201,8 @@ class FlaxGenerationMixin:
|
|||||||
The id of the `beginning-of-sequence` token.
|
The id of the `beginning-of-sequence` token.
|
||||||
eos_token_id (:obj:`int`, `optional`):
|
eos_token_id (:obj:`int`, `optional`):
|
||||||
The id of the `end-of-sequence` token.
|
The id of the `end-of-sequence` token.
|
||||||
|
num_beams (:obj:`int`, `optional`, defaults to 1):
|
||||||
|
Number of beams for beam search. 1 means no beam search.
|
||||||
decoder_start_token_id (:obj:`int`, `optional`):
|
decoder_start_token_id (:obj:`int`, `optional`):
|
||||||
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
|
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
|
||||||
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
@@ -204,9 +248,27 @@ class FlaxGenerationMixin:
|
|||||||
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
||||||
|
|
||||||
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
||||||
|
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
||||||
|
|
||||||
if do_sample:
|
if not do_sample and num_beams == 1:
|
||||||
|
logits_processor = self._get_logits_processor(
|
||||||
|
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||||
|
)
|
||||||
|
return self._greedy_search(
|
||||||
|
input_ids,
|
||||||
|
max_length,
|
||||||
|
pad_token_id,
|
||||||
|
eos_token_id,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
trace=trace,
|
||||||
|
params=params,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
elif do_sample and num_beams == 1:
|
||||||
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
|
||||||
|
logits_processor = self._get_logits_processor(
|
||||||
|
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||||
|
)
|
||||||
return self._sample(
|
return self._sample(
|
||||||
input_ids,
|
input_ids,
|
||||||
max_length,
|
max_length,
|
||||||
@@ -214,20 +276,43 @@ class FlaxGenerationMixin:
|
|||||||
eos_token_id,
|
eos_token_id,
|
||||||
prng_key,
|
prng_key,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
trace=trace,
|
||||||
|
params=params,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
)
|
||||||
|
elif not do_sample and num_beams > 1:
|
||||||
|
# broadcast input_ids & encoder_outputs
|
||||||
|
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
|
||||||
|
|
||||||
|
if "encoder_outputs" in model_kwargs:
|
||||||
|
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
|
||||||
|
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
|
||||||
|
)
|
||||||
|
|
||||||
|
if "attention_mask" in model_kwargs:
|
||||||
|
model_kwargs["attention_mask"] = self._expand_to_num_beams(
|
||||||
|
model_kwargs["attention_mask"], num_beams=num_beams
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_processor = self._get_logits_processor(
|
||||||
|
no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._beam_search(
|
||||||
|
input_ids,
|
||||||
|
max_length,
|
||||||
|
pad_token_id,
|
||||||
|
eos_token_id,
|
||||||
|
length_penalty=length_penalty,
|
||||||
|
early_stopping=early_stopping,
|
||||||
|
logits_processor=logits_processor,
|
||||||
trace=trace,
|
trace=trace,
|
||||||
params=params,
|
params=params,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._greedy_search(
|
raise NotImplementedError("`Beam sampling is currently not implemented.")
|
||||||
input_ids,
|
|
||||||
max_length,
|
|
||||||
pad_token_id,
|
|
||||||
eos_token_id,
|
|
||||||
trace=trace,
|
|
||||||
params=params,
|
|
||||||
model_kwargs=model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_logits_warper(
|
def _get_logits_warper(
|
||||||
self, top_k: int = None, top_p: float = None, temperature: float = None
|
self, top_k: int = None, top_p: float = None, temperature: float = None
|
||||||
@@ -255,12 +340,51 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
return warpers
|
return warpers
|
||||||
|
|
||||||
|
def _get_logits_processor(
|
||||||
|
self,
|
||||||
|
no_repeat_ngram_size: int,
|
||||||
|
min_length: int,
|
||||||
|
max_length: int,
|
||||||
|
eos_token_id: int,
|
||||||
|
forced_bos_token_id: int,
|
||||||
|
forced_eos_token_id: int,
|
||||||
|
) -> FlaxLogitsProcessorList:
|
||||||
|
"""
|
||||||
|
This method returns a :obj:`~transformers.FlaxLogitsProcessorList` list object that contains all relevant
|
||||||
|
:obj:`~transformers.FlaxLogitsProcessor` instances used to modify the scores of the language model head.
|
||||||
|
"""
|
||||||
|
processors = FlaxLogitsProcessorList()
|
||||||
|
|
||||||
|
# init warp parameters
|
||||||
|
no_repeat_ngram_size = (
|
||||||
|
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
||||||
|
)
|
||||||
|
min_length = min_length if min_length is not None else self.config.min_length
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||||
|
forced_bos_token_id = (
|
||||||
|
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
|
||||||
|
)
|
||||||
|
forced_eos_token_id = (
|
||||||
|
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
|
||||||
|
# all samplers can be found in `generation_utils_samplers.py`
|
||||||
|
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||||
|
processors.append(FlaxMinLengthLogitsProcessor(min_length, eos_token_id))
|
||||||
|
if forced_bos_token_id is not None:
|
||||||
|
processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||||
|
if forced_eos_token_id is not None:
|
||||||
|
processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||||
|
return processors
|
||||||
|
|
||||||
def _greedy_search(
|
def _greedy_search(
|
||||||
self,
|
self,
|
||||||
input_ids: None,
|
input_ids: None,
|
||||||
max_length: Optional[int] = None,
|
max_length: Optional[int] = None,
|
||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
|
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||||
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||||
@@ -293,7 +417,7 @@ class FlaxGenerationMixin:
|
|||||||
state = GreedyState(
|
state = GreedyState(
|
||||||
cur_len=cur_len,
|
cur_len=cur_len,
|
||||||
sequences=sequences,
|
sequences=sequences,
|
||||||
current_token=input_ids,
|
running_token=input_ids,
|
||||||
is_sent_finished=is_sent_finished,
|
is_sent_finished=is_sent_finished,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -307,8 +431,13 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
def greedy_search_body_fn(state):
|
def greedy_search_body_fn(state):
|
||||||
"""state update fn."""
|
"""state update fn."""
|
||||||
model_outputs = model(state.current_token, params=params, **state.model_kwargs)
|
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
||||||
next_token = jnp.argmax(model_outputs.logits[:, -1], axis=-1)
|
logits = model_outputs.logits[:, -1]
|
||||||
|
|
||||||
|
# apply min_length, ...
|
||||||
|
logits = logits_processor(state.sequences, logits, state.cur_len)
|
||||||
|
|
||||||
|
next_token = jnp.argmax(logits, axis=-1)
|
||||||
|
|
||||||
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
||||||
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
||||||
@@ -319,7 +448,7 @@ class FlaxGenerationMixin:
|
|||||||
return GreedyState(
|
return GreedyState(
|
||||||
cur_len=state.cur_len + 1,
|
cur_len=state.cur_len + 1,
|
||||||
sequences=next_sequences,
|
sequences=next_sequences,
|
||||||
current_token=next_token,
|
running_token=next_token,
|
||||||
is_sent_finished=next_is_sent_finished,
|
is_sent_finished=next_is_sent_finished,
|
||||||
model_kwargs=next_model_kwargs,
|
model_kwargs=next_model_kwargs,
|
||||||
)
|
)
|
||||||
@@ -342,6 +471,7 @@ class FlaxGenerationMixin:
|
|||||||
pad_token_id: Optional[int] = None,
|
pad_token_id: Optional[int] = None,
|
||||||
eos_token_id: Optional[int] = None,
|
eos_token_id: Optional[int] = None,
|
||||||
prng_key: Optional[jax_xla.DeviceArray] = None,
|
prng_key: Optional[jax_xla.DeviceArray] = None,
|
||||||
|
logits_processor: Optional[FlaxLogitsProcessorList] = None,
|
||||||
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
logits_warper: Optional[FlaxLogitsProcessorList] = None,
|
||||||
trace: bool = True,
|
trace: bool = True,
|
||||||
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
|
||||||
@@ -377,7 +507,7 @@ class FlaxGenerationMixin:
|
|||||||
state = SampleState(
|
state = SampleState(
|
||||||
cur_len=cur_len,
|
cur_len=cur_len,
|
||||||
sequences=sequences,
|
sequences=sequences,
|
||||||
current_token=input_ids,
|
running_token=input_ids,
|
||||||
is_sent_finished=is_sent_finished,
|
is_sent_finished=is_sent_finished,
|
||||||
prng_key=prng_key,
|
prng_key=prng_key,
|
||||||
model_kwargs=model_kwargs,
|
model_kwargs=model_kwargs,
|
||||||
@@ -393,12 +523,14 @@ class FlaxGenerationMixin:
|
|||||||
def sample_search_body_fn(state):
|
def sample_search_body_fn(state):
|
||||||
"""state update fn."""
|
"""state update fn."""
|
||||||
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
||||||
model_outputs = model(state.current_token, params=params, **state.model_kwargs)
|
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
||||||
|
|
||||||
logits = model_outputs.logits[:, -1]
|
logits = model_outputs.logits[:, -1]
|
||||||
|
|
||||||
|
# apply min_length, ...
|
||||||
|
logits = logits_processor(state.sequences, logits, state.cur_len)
|
||||||
# apply top_k, top_p, temperature
|
# apply top_k, top_p, temperature
|
||||||
logits = logits_warper(state.sequences, logits)
|
logits = logits_warper(logits, logits, state.cur_len)
|
||||||
|
|
||||||
next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
|
next_token = jax.random.categorical(prng_key, model_outputs.logits[:, -1], axis=-1)
|
||||||
|
|
||||||
@@ -412,7 +544,7 @@ class FlaxGenerationMixin:
|
|||||||
return SampleState(
|
return SampleState(
|
||||||
cur_len=state.cur_len + 1,
|
cur_len=state.cur_len + 1,
|
||||||
sequences=next_sequences,
|
sequences=next_sequences,
|
||||||
current_token=next_token,
|
running_token=next_token,
|
||||||
is_sent_finished=next_is_sent_finished,
|
is_sent_finished=next_is_sent_finished,
|
||||||
model_kwargs=next_model_kwargs,
|
model_kwargs=next_model_kwargs,
|
||||||
prng_key=prng_key_next,
|
prng_key=prng_key_next,
|
||||||
@@ -428,3 +560,251 @@ class FlaxGenerationMixin:
|
|||||||
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
||||||
|
|
||||||
return FlaxSampleOutput(sequences=state.sequences)
|
return FlaxSampleOutput(sequences=state.sequences)
|
||||||
|
|
||||||
|
def _beam_search(
    self,
    input_ids: jax_xla.DeviceArray,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    logits_processor: Optional[FlaxLogitsProcessorList] = None,
    trace: bool = True,
    params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
):
    """
    Run beam search and return the best hypothesis for every batch item.

    This beam search function is heavily inspired by Flax's official example:
    https://github.com/google/flax/blob/master/examples/wmt/train.py#L254

    Args:
        input_ids: Prompt token ids; unpacked below as ``(batch_size, num_beams, cur_len)``.
        max_length: Length of the generated sequences. Falls back to ``self.config.max_length``.
        pad_token_id: Id used to fill not-yet-generated positions. Falls back to ``self.config.pad_token_id``.
        eos_token_id: Id that marks a hypothesis as finished. Falls back to ``self.config.eos_token_id``.
        length_penalty: Exponent applied to the length when normalising scores. Falls back to
            ``self.config.length_penalty``.
        early_stopping: If true, stop once every beam of every batch item has finished. Falls back to
            ``self.config.early_stopping``.
        logits_processor: Callable invoked as ``logits_processor(input_ids, log_probs, cur_len)`` on the
            flattened (batch * beam) arrays at every step.
        trace: If true, the loop is run with ``lax.while_loop``; otherwise ``self._run_loop_in_debug`` is used.
        params: Forwarded to the model call as ``params``.
        model_kwargs: Extra model inputs. ``encoder_outputs`` / ``attention_mask`` entries, when present, have
            their beam dimension flattened in place.

    Returns:
        ``FlaxBeamSearchOutput`` holding, per batch item, the last (best) beam's ``sequences`` and ``scores``.
    """

    def flatten_beam_dim(tensor):
        """Flattens the first two dimensions of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

    def unflatten_beam_dim(tensor, batch_size, num_beams):
        """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
        # ignore scalars (e.g. cache index)
        if tensor.ndim == 0:
            return tensor
        return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

    def gather_beams(nested, beam_indices, batch_size, new_num_beams):
        """
        Gathers the beam slices indexed by beam_indices into new beam array.
        """
        batch_indices = jnp.reshape(
            jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
        )

        def gather_fn(tensor):
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            else:
                return tensor[batch_indices, beam_indices]

        return jax.tree_map(gather_fn, nested)

    # init values
    max_length = max_length if max_length is not None else self.config.max_length
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    # The signature allows `model_kwargs=None`; the membership checks below need a mapping.
    model_kwargs = model_kwargs if model_kwargs is not None else {}

    batch_size, num_beams, cur_len = input_ids.shape

    eos_token_id = jnp.array(eos_token_id)
    pad_token_id = jnp.array(pad_token_id)
    cur_len = jnp.array(cur_len)

    # per batch,beam-item holding current token in loop.
    # `sequences` stores finished hypotheses, `running_sequences` the still-active ones (seeded with the prompt).
    sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
    running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

    # per batch,beam-item state bit indicating if sentence has finished.
    is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

    # per batch,beam-item score, logprobs
    # Only the first beam starts "alive"; the others get a very low score so that identical
    # prompt copies do not fill the top-k with duplicates on the first step.
    running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
    scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

    # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
    # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
    model = self.decode if self.config.is_encoder_decoder else self

    # flatten beam dim
    if "encoder_outputs" in model_kwargs:
        model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
            model_kwargs["encoder_outputs"]["last_hidden_state"]
        )
    if "attention_mask" in model_kwargs:
        model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

    # initialize model specific kwargs
    model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

    # initialize state
    state = BeamSearchState(
        cur_len=cur_len,
        running_sequences=running_sequences,
        running_scores=running_scores,
        sequences=sequences,
        scores=scores,
        is_sent_finished=is_sent_finished,
        model_kwargs=model_kwargs,
    )

    def beam_search_cond_fn(state):
        """beam search state termination condition fn."""

        # 1. is less than max length?
        not_max_length_yet = state.cur_len < max_length

        # 2. can the new beams still improve?
        # The best running beam sits in the last slot because the top-k results are flipped in the body.
        best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
        worst_finished_score = jnp.where(
            state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
        )
        # Keep going while *any* batch item can still improve; requiring it of *all* items would
        # stop the whole batch as soon as a single item is exhausted.
        improvement_still_possible = jnp.any(worst_finished_score < best_running_score)

        # 3. is there still a beam that has not finished?
        still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

        return not_max_length_yet & still_open_beam & improvement_still_possible

    def beam_search_body_fn(state):
        """beam search state update fn."""
        # 1. Forward current tokens
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model. Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # unflatten beam dimension
        # Unflatten beam dimension in attention cache arrays
        input_token = flatten_beam_dim(
            lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
        )
        model_outputs = model(input_token, params=params, **state.model_kwargs)
        logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
        cache = jax.tree_map(
            lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
        )

        # 2. Compute log probs
        # get log probabilities from logits,
        # process logits with processors (*e.g.* min_length, ...), and
        # add new logprobs to existing running logprobs scores.
        log_probs = jax.nn.log_softmax(logits)
        # Use the sequences carried in the loop state: the closed-over `running_sequences`
        # only ever contains the initial prompt.
        log_probs = logits_processor(
            flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
        )
        log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
        log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
        vocab_size = log_probs.shape[2]
        log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

        # 3. Retrieve top-K
        # Each item in batch has num_beams * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        # Gather the top 2*K scores from _all_ beams.
        # Gather 2*k top beams.
        # Recover the beam index by floor division.
        # Recover token id by modulo division and expand Id array for broadcasting.
        # Update sequences for the 2*K top-k new sequences.
        beams_to_keep = 2 * num_beams
        topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
        topk_beam_indices = topk_indices // vocab_size
        topk_running_sequences = gather_beams(
            state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
        )
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

        # 4. Check which sequences have ended
        # Update current sequences:
        # Did any of these sequences reach an end marker?
        # To prevent these just finished sequences from being added to the current sequences
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
        topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

        # 5. Get running sequences scores for next
        # Determine the top k beam indices (from top 2*k beams) from log probs
        # and gather top k beams (from top 2*k beams).
        next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
        next_running_sequences, next_running_scores = gather_beams(
            [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
        )

        # 6. Process topk logits
        # Further process log probs:
        # - add length penalty
        # - make sure no scores can be added anymore if beam is full
        # - make sure still running sequences cannot be chosen as finalized beam
        topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
        beams_in_batch_are_full = (
            jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
            & early_stopping
        )
        add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
        topk_log_probs += add_penalty * np.array(-1.0e7)

        # 7. Get scores, sequences, is sentence finished for next.
        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams
        merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
        merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
        merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
        topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
        next_sequences, next_scores, next_is_sent_finished = gather_beams(
            [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
        )

        # 8. Update model kwargs.
        # Determine the top k beam indices from the original set of all beams.
        # With these, gather the top k beam-associated caches.
        next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
        next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
        model_outputs["past_key_values"] = jax.tree_map(flatten_beam_dim, next_cache)
        next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

        return BeamSearchState(
            cur_len=state.cur_len + 1,
            running_scores=next_running_scores,
            running_sequences=next_running_sequences,
            scores=next_scores,
            sequences=next_sequences,
            is_sent_finished=next_is_sent_finished,
            model_kwargs=next_model_kwargs,
        )

    # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
    state = beam_search_body_fn(state)

    if not trace:
        state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
    else:
        state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return running sequences for that batch item.
    any_finished = jnp.any(state.is_sent_finished, axis=1)
    sequences = jnp.where(any_finished[:, None, None], state.sequences, state.running_sequences)
    scores = jnp.where(any_finished[:, None], state.scores, state.running_scores)

    # take best beam for each batch
    # (beams are stored in ascending score order because of the `jnp.flip` calls above)
    sequences = sequences[:, -1]
    scores = scores[:, -1]

    return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
|
||||||
|
|||||||
@@ -2,6 +2,24 @@
|
|||||||
from ..file_utils import requires_backends
|
from ..file_utils import requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxForcedBOSTokenLogitsProcessor:
    """Placeholder that defers to `requires_backends` with the "flax" backend on any use."""

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored; only the backend check runs.
        backends = ["flax"]
        requires_backends(self, backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, performed on the class itself.
        backends = ["flax"]
        requires_backends(cls, backends)
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxForcedEOSTokenLogitsProcessor:
    """Placeholder that defers to `requires_backends` with the "flax" backend on any use."""

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored; only the backend check runs.
        backends = ["flax"]
        requires_backends(self, backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, performed on the class itself.
        backends = ["flax"]
        requires_backends(cls, backends)
|
||||||
|
|
||||||
|
|
||||||
class FlaxLogitsProcessor:
|
class FlaxLogitsProcessor:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
@@ -25,6 +43,15 @@ class FlaxLogitsWarper:
|
|||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxMinLengthLogitsProcessor:
    """Placeholder that defers to `requires_backends` with the "flax" backend on any use."""

    def __init__(self, *args, **kwargs):
        # All arguments are accepted and ignored; only the backend check runs.
        backends = ["flax"]
        requires_backends(self, backends)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check, performed on the class itself.
        backends = ["flax"]
        requires_backends(cls, backends)
|
||||||
|
|
||||||
|
|
||||||
class FlaxTemperatureLogitsWarper:
|
class FlaxTemperatureLogitsWarper:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
|
|||||||
@@ -28,7 +28,10 @@ if is_flax_available():
|
|||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from transformers.generation_flax_logits_process import (
|
from transformers.generation_flax_logits_process import (
|
||||||
|
FlaxForcedBOSTokenLogitsProcessor,
|
||||||
|
FlaxForcedEOSTokenLogitsProcessor,
|
||||||
FlaxLogitsProcessorList,
|
FlaxLogitsProcessorList,
|
||||||
|
FlaxMinLengthLogitsProcessor,
|
||||||
FlaxTemperatureLogitsWarper,
|
FlaxTemperatureLogitsWarper,
|
||||||
FlaxTopKLogitsWarper,
|
FlaxTopKLogitsWarper,
|
||||||
FlaxTopPLogitsWarper,
|
FlaxTopPLogitsWarper,
|
||||||
@@ -57,8 +60,8 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
|
temp_dist_warper_sharper = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||||
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
|
temp_dist_warper_smoother = FlaxTemperatureLogitsWarper(temperature=1.3)
|
||||||
|
|
||||||
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy()), axis=-1)
|
warped_prob_sharp = jax.nn.softmax(temp_dist_warper_sharper(input_ids, scores.copy(), cur_len=None), axis=-1)
|
||||||
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy()), axis=-1)
|
warped_prob_smooth = jax.nn.softmax(temp_dist_warper_smoother(input_ids, scores.copy(), cur_len=None), axis=-1)
|
||||||
|
|
||||||
# uniform distribution stays uniform
|
# uniform distribution stays uniform
|
||||||
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
self.assertTrue(jnp.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3))
|
||||||
@@ -83,7 +86,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||||
|
|
||||||
scores = top_k_warp(input_ids, ramp_logits)
|
scores = top_k_warp(input_ids, ramp_logits, cur_len=None)
|
||||||
|
|
||||||
# check that correct tokens are filtered
|
# check that correct tokens are filtered
|
||||||
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
self.assertListEqual(jnp.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False])
|
||||||
@@ -94,7 +97,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
top_k_warp_safety_check = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)
|
||||||
|
|
||||||
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
|
ramp_logits = np.broadcast_to(np.arange(length)[None, :], (batch_size, length)).copy()
|
||||||
scores = top_k_warp_safety_check(input_ids, ramp_logits)
|
scores = top_k_warp_safety_check(input_ids, ramp_logits, cur_len=None)
|
||||||
|
|
||||||
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
# min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified
|
||||||
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
|
self.assertListEqual((scores == 0.0).sum(axis=-1).tolist(), [2, 2])
|
||||||
@@ -108,7 +111,7 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))
|
||||||
|
|
||||||
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
top_p_warp = FlaxTopPLogitsWarper(0.7)
|
||||||
filtered_dist = np.exp(top_p_warp(input_ids, dist))
|
filtered_dist = np.exp(top_p_warp(input_ids, dist, cur_len=None))
|
||||||
|
|
||||||
# dist should be filtered to keep min num values so that sum is >= 0.7
|
# dist should be filtered to keep min num values so that sum is >= 0.7
|
||||||
# exp (-inf) => 0
|
# exp (-inf) => 0
|
||||||
@@ -125,15 +128,81 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
|
|
||||||
# make sure at least 2 tokens are kept
|
# make sure at least 2 tokens are kept
|
||||||
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
|
||||||
filtered_dist = top_p_warp(input_ids, ramp_logits)
|
filtered_dist = top_p_warp(input_ids, ramp_logits, cur_len=None)
|
||||||
|
|
||||||
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
|
||||||
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
|
self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])
|
||||||
|
|
||||||
|
def test_min_length_dist_processor(self):
    """EOS must be masked to -inf below `min_length` and left untouched above it."""
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0

    processor = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
    input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size)

    # cur_len=5 is below min_length=10: the EOS column must be -inf for every row
    uniform_logits = self._get_uniform_logits(batch_size, vocab_size)
    masked_scores = processor(input_ids, uniform_logits, cur_len=5)
    expected_eos_column = [-float("inf")] * batch_size
    self.assertListEqual(masked_scores[:, eos_token_id].tolist(), expected_eos_column)

    # cur_len=15 is above min_length=10: nothing may be masked anymore
    uniform_logits = self._get_uniform_logits(batch_size, vocab_size)
    unmasked_scores = processor(input_ids, uniform_logits, cur_len=15)
    self.assertFalse(jnp.isinf(unmasked_scores).any())
|
||||||
|
|
||||||
|
def test_forced_bos_token_logits_processor(self):
    """The BOS token is forced at `cur_len == 1` and not at a later step."""
    vocab_size = 20
    batch_size = 4
    bos_token_id = 0

    processor = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
    input_ids = ids_tensor((batch_size, 1), vocab_size=vocab_size)

    # at cur_len=1 every score except the one for bos_token_id must be -inf
    forced_scores = processor(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=1)
    other_tokens_masked = jnp.isneginf(forced_scores[:, bos_token_id + 1 :]).all()
    self.assertTrue(other_tokens_masked)
    # the score for bos_token_id itself should be zero
    self.assertListEqual(forced_scores[:, bos_token_id].tolist(), [0] * batch_size)

    # at cur_len=3 the processor must leave the scores alone
    free_scores = processor(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=3)
    self.assertFalse(jnp.isinf(free_scores).any())
|
||||||
|
|
||||||
|
def test_forced_eos_token_logits_processor(self):
    """The EOS token is forced on the last step before `max_length` and not earlier."""
    vocab_size = 20
    batch_size = 4
    eos_token_id = 0
    max_length = 5

    processor = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
    input_ids = ids_tensor((batch_size, 4), vocab_size=vocab_size)

    # at cur_len=4 (max_length - 1) every score except the one for eos_token_id must be -inf
    forced_scores = processor(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=4)
    other_tokens_masked = jnp.isneginf(forced_scores[:, eos_token_id + 1 :]).all()
    self.assertTrue(other_tokens_masked)
    # the score for eos_token_id itself should be zero
    self.assertListEqual(forced_scores[:, eos_token_id].tolist(), [0] * batch_size)

    # at cur_len=3 max_length is not reached yet, so nothing may be masked
    free_scores = processor(input_ids, self._get_uniform_logits(batch_size, vocab_size), cur_len=3)
    self.assertFalse(jnp.isinf(free_scores).any())
|
||||||
|
|
||||||
def test_processor_list(self):
|
def test_processor_list(self):
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
sequence_length = 10
|
sequence_length = 10
|
||||||
vocab_size = 15
|
vocab_size = 15
|
||||||
|
eos_token_id = 2
|
||||||
|
bos_token_id = 1
|
||||||
|
max_length = 15
|
||||||
|
|
||||||
# dummy input_ids and scores
|
# dummy input_ids and scores
|
||||||
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||||
@@ -147,14 +216,83 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
top_k_warp = FlaxTopKLogitsWarper(3)
|
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||||
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||||
|
|
||||||
|
# instantiate all logits processors
|
||||||
|
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||||
|
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||||
|
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||||
|
|
||||||
|
cur_len = 10
|
||||||
|
|
||||||
# no processor list
|
# no processor list
|
||||||
scores = temp_dist_warp(input_ids, scores)
|
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
|
||||||
scores = top_k_warp(input_ids, scores)
|
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
|
||||||
scores = top_p_warp(input_ids, scores)
|
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||||
|
|
||||||
# with processor list
|
# with processor list
|
||||||
processor = FlaxLogitsProcessorList([temp_dist_warp, top_k_warp, top_p_warp])
|
processor = FlaxLogitsProcessorList(
|
||||||
scores_comp = processor(input_ids, scores_comp)
|
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
|
||||||
|
)
|
||||||
|
scores_comp = processor(input_ids, scores_comp, cur_len=cur_len)
|
||||||
|
|
||||||
|
# scores should be equal
|
||||||
|
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
||||||
|
|
||||||
|
# input_ids should never be changed
|
||||||
|
self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
|
||||||
|
|
||||||
|
def test_processor_list_jitted(self):
|
||||||
|
batch_size = 4
|
||||||
|
sequence_length = 10
|
||||||
|
vocab_size = 15
|
||||||
|
eos_token_id = 2
|
||||||
|
bos_token_id = 1
|
||||||
|
max_length = 15
|
||||||
|
|
||||||
|
# dummy input_ids and scores
|
||||||
|
input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
|
||||||
|
input_ids_comp = input_ids.copy()
|
||||||
|
|
||||||
|
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||||
|
scores_comp = scores.copy()
|
||||||
|
|
||||||
|
# instantiate all dist processors
|
||||||
|
temp_dist_warp = FlaxTemperatureLogitsWarper(temperature=0.5)
|
||||||
|
top_k_warp = FlaxTopKLogitsWarper(3)
|
||||||
|
top_p_warp = FlaxTopPLogitsWarper(0.8)
|
||||||
|
|
||||||
|
# instantiate all logits processors
|
||||||
|
min_dist_proc = FlaxMinLengthLogitsProcessor(min_length=10, eos_token_id=eos_token_id)
|
||||||
|
bos_dist_proc = FlaxForcedBOSTokenLogitsProcessor(bos_token_id=bos_token_id)
|
||||||
|
eos_dist_proc = FlaxForcedEOSTokenLogitsProcessor(max_length=max_length, eos_token_id=eos_token_id)
|
||||||
|
|
||||||
|
cur_len = 10
|
||||||
|
|
||||||
|
# no processor list
|
||||||
|
def run_no_processor_list(input_ids, scores, cur_len):
|
||||||
|
scores = temp_dist_warp(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = top_k_warp(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = top_p_warp(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = min_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = bos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||||
|
scores = eos_dist_proc(input_ids, scores, cur_len=cur_len)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
# with processor list
|
||||||
|
def run_processor_list(input_ids, scores, cur_len):
|
||||||
|
processor = FlaxLogitsProcessorList(
|
||||||
|
[temp_dist_warp, top_k_warp, top_p_warp, min_dist_proc, bos_dist_proc, eos_dist_proc]
|
||||||
|
)
|
||||||
|
scores = processor(input_ids, scores, cur_len=cur_len)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
jitted_run_no_processor_list = jax.jit(run_no_processor_list)
|
||||||
|
jitted_run_processor_list = jax.jit(run_processor_list)
|
||||||
|
|
||||||
|
scores = jitted_run_no_processor_list(input_ids, scores, cur_len)
|
||||||
|
scores_comp = jitted_run_processor_list(input_ids, scores_comp, cur_len)
|
||||||
|
|
||||||
# scores should be equal
|
# scores should be equal
|
||||||
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))
|
||||||
|
|||||||
@@ -110,6 +110,23 @@ class FlaxGenerationTesterMixin:
|
|||||||
|
|
||||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_beam_search_generate(self):
    """Beam search (num_beams=2) must give the same sequences eagerly and under `jit`."""
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.do_sample = False
    config.max_length = max_length
    config.num_beams = 2

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # eager run: output must be padded/generated up to max_length
        eager_sequences = model.generate(input_ids).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # jitted run must reproduce the eager result exactly
        jitted_sequences = jit(model.generate)(input_ids).sequences
        self.assertListEqual(eager_sequences.tolist(), jitted_sequences.tolist())
|
||||||
|
|
||||||
def test_sample_generate_logits_warper(self):
|
def test_sample_generate_logits_warper(self):
|
||||||
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
config.do_sample = True
|
config.do_sample = True
|
||||||
@@ -117,6 +134,46 @@ class FlaxGenerationTesterMixin:
|
|||||||
config.temperature = 0.8
|
config.temperature = 0.8
|
||||||
config.top_k = 10
|
config.top_k = 10
|
||||||
config.top_p = 0.3
|
config.top_p = 0.3
|
||||||
|
config.min_length = 1
|
||||||
|
config.forced_bos_token_id = 8
|
||||||
|
config.forced_eos_token_id = 9
|
||||||
|
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
generation_outputs = model.generate(input_ids).sequences
|
||||||
|
self.assertEqual(generation_outputs.shape[-1], max_length)
|
||||||
|
|
||||||
|
jit_generate = jit(model.generate)
|
||||||
|
jit_generation_outputs = jit_generate(input_ids).sequences
|
||||||
|
|
||||||
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_greedy_generate_logits_warper(self):
    """Generation with min-length / forced BOS / forced EOS settings must match under `jit`."""
    config, input_ids, _, max_length = self._get_input_ids_and_config()
    config.max_length = max_length
    # settings that enable the min-length and forced-token logits processors
    config.min_length = 1
    config.forced_bos_token_id = 8
    config.forced_eos_token_id = 9

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # eager run: output must have length max_length
        eager_sequences = model.generate(input_ids).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # jitted run must reproduce the eager result exactly
        jitted_sequences = jit(model.generate)(input_ids).sequences
        self.assertListEqual(eager_sequences.tolist(), jitted_sequences.tolist())
|
||||||
|
|
||||||
|
def test_beam_search_generate_logits_warper(self):
|
||||||
|
config, input_ids, _, max_length = self._get_input_ids_and_config()
|
||||||
|
config.max_length = max_length
|
||||||
|
config.num_beams = 2
|
||||||
|
config.min_length = 1
|
||||||
|
config.forced_bos_token_id = 8
|
||||||
|
config.forced_eos_token_id = 9
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
@@ -168,3 +225,23 @@ class FlaxGenerationTesterMixin:
|
|||||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||||
|
|
||||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
def test_beam_search_generate_attn_mask(self):
    """Beam search with a left-padded attention mask must match between eager and `jit` runs."""
    config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()

    # pad attention mask on the left
    first_position = (0, 0)
    attention_mask = jax.ops.index_update(attention_mask, first_position, 0)

    config.num_beams = 2
    config.max_length = max_length

    for model_class in self.all_generative_model_classes:
        model = model_class(config)

        # eager run: output must have length max_length
        eager_sequences = model.generate(input_ids, attention_mask=attention_mask).sequences
        self.assertEqual(eager_sequences.shape[-1], max_length)

        # jitted run must reproduce the eager result exactly
        jitted_generate = jit(model.generate)
        jitted_sequences = jitted_generate(input_ids, attention_mask=attention_mask).sequences
        self.assertListEqual(eager_sequences.tolist(), jitted_sequences.tolist())
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user