Generate: FLAX infers pad token in its absence and has functional example (#21009)
This commit is contained in:
@@ -305,10 +305,10 @@ class FlaxGenerationMixin:
|
||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||
>>> input_context = "The dog"
|
||||
>>> # encode input context
|
||||
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
|
||||
>>> inputs = tokenizer(input_context, return_tensors="np")
|
||||
>>> # generate candidates using sampling
|
||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
>>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
|
||||
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
|
||||
```"""
|
||||
# Validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
@@ -323,6 +323,17 @@ class FlaxGenerationMixin:
|
||||
)
|
||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
if model_kwargs.get("attention_mask") is None:
|
||||
logger.warning(
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
pad_token_id = eos_token_id
|
||||
|
||||
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
||||
|
||||
@@ -525,8 +536,8 @@ class FlaxGenerationMixin:
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
||||
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
@@ -614,8 +625,8 @@ class FlaxGenerationMixin:
|
||||
|
||||
batch_size, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
||||
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch-item holding current token in loop.
|
||||
@@ -748,8 +759,8 @@ class FlaxGenerationMixin:
|
||||
|
||||
batch_size, num_beams, cur_len = input_ids.shape
|
||||
|
||||
eos_token_id = jnp.array(eos_token_id)
|
||||
pad_token_id = jnp.array(pad_token_id)
|
||||
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
||||
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
||||
cur_len = jnp.array(cur_len)
|
||||
|
||||
# per batch,beam-item holding current token in loop.
|
||||
|
||||
@@ -702,11 +702,11 @@ class TFGenerationMixin:
|
||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||
)
|
||||
logger.warning(
|
||||
f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
|
||||
" sequence"
|
||||
)
|
||||
generation_config.pad_token_id = generation_config.eos_token_id
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
if isinstance(eos_token_id, list):
|
||||
eos_token_id = eos_token_id[0]
|
||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
use_xla = not tf.executing_eagerly()
|
||||
if use_xla and not self.supports_xla_generation:
|
||||
|
||||
@@ -13,6 +13,7 @@ docs/source/en/model_doc/tapex.mdx
|
||||
docs/source/en/model_doc/donut.mdx
|
||||
docs/source/en/model_doc/encoder-decoder.mdx
|
||||
src/transformers/generation/configuration_utils.py
|
||||
src/transformers/generation/flax_utils.py
|
||||
src/transformers/generation/tf_utils.py
|
||||
src/transformers/generation/utils.py
|
||||
src/transformers/models/albert/configuration_albert.py
|
||||
|
||||
Reference in New Issue
Block a user