Generate: FLAX infers pad token in its absence and has functional example (#21009)
This commit is contained in:
@@ -305,10 +305,10 @@ class FlaxGenerationMixin:
|
|||||||
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
|
||||||
>>> input_context = "The dog"
|
>>> input_context = "The dog"
|
||||||
>>> # encode input context
|
>>> # encode input context
|
||||||
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
|
>>> inputs = tokenizer(input_context, return_tensors="np")
|
||||||
>>> # generate candidates using sampling
|
>>> # generate candidates using sampling
|
||||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
>>> outputs = model.generate(**inputs, max_length=20, top_k=30, do_sample=True)
|
||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
|
||||||
```"""
|
```"""
|
||||||
# Validate the `.generate()` call
|
# Validate the `.generate()` call
|
||||||
self._validate_model_class()
|
self._validate_model_class()
|
||||||
@@ -323,6 +323,17 @@ class FlaxGenerationMixin:
|
|||||||
)
|
)
|
||||||
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
||||||
|
|
||||||
|
if pad_token_id is None and eos_token_id is not None:
|
||||||
|
if model_kwargs.get("attention_mask") is None:
|
||||||
|
logger.warning(
|
||||||
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
|
)
|
||||||
|
if isinstance(eos_token_id, list):
|
||||||
|
eos_token_id = eos_token_id[0]
|
||||||
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
|
pad_token_id = eos_token_id
|
||||||
|
|
||||||
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
if decoder_start_token_id is None and self.config.is_encoder_decoder:
|
||||||
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
|
||||||
|
|
||||||
@@ -525,8 +536,8 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
batch_size, cur_len = input_ids.shape
|
batch_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
eos_token_id = jnp.array(eos_token_id)
|
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
||||||
pad_token_id = jnp.array(pad_token_id)
|
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
||||||
cur_len = jnp.array(cur_len)
|
cur_len = jnp.array(cur_len)
|
||||||
|
|
||||||
# per batch-item holding current token in loop.
|
# per batch-item holding current token in loop.
|
||||||
@@ -614,8 +625,8 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
batch_size, cur_len = input_ids.shape
|
batch_size, cur_len = input_ids.shape
|
||||||
|
|
||||||
eos_token_id = jnp.array(eos_token_id)
|
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
||||||
pad_token_id = jnp.array(pad_token_id)
|
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
||||||
cur_len = jnp.array(cur_len)
|
cur_len = jnp.array(cur_len)
|
||||||
|
|
||||||
# per batch-item holding current token in loop.
|
# per batch-item holding current token in loop.
|
||||||
@@ -748,8 +759,8 @@ class FlaxGenerationMixin:
|
|||||||
|
|
||||||
batch_size, num_beams, cur_len = input_ids.shape
|
batch_size, num_beams, cur_len = input_ids.shape
|
||||||
|
|
||||||
eos_token_id = jnp.array(eos_token_id)
|
eos_token_id = jnp.array(eos_token_id, dtype=jnp.int32 if eos_token_id is not None else None)
|
||||||
pad_token_id = jnp.array(pad_token_id)
|
pad_token_id = jnp.array(pad_token_id, dtype=jnp.int32)
|
||||||
cur_len = jnp.array(cur_len)
|
cur_len = jnp.array(cur_len)
|
||||||
|
|
||||||
# per batch,beam-item holding current token in loop.
|
# per batch,beam-item holding current token in loop.
|
||||||
|
|||||||
@@ -702,11 +702,11 @@ class TFGenerationMixin:
|
|||||||
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
)
|
)
|
||||||
logger.warning(
|
eos_token_id = generation_config.eos_token_id
|
||||||
f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
|
if isinstance(eos_token_id, list):
|
||||||
" sequence"
|
eos_token_id = eos_token_id[0]
|
||||||
)
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
generation_config.pad_token_id = generation_config.eos_token_id
|
generation_config.pad_token_id = eos_token_id
|
||||||
|
|
||||||
use_xla = not tf.executing_eagerly()
|
use_xla = not tf.executing_eagerly()
|
||||||
if use_xla and not self.supports_xla_generation:
|
if use_xla and not self.supports_xla_generation:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ docs/source/en/model_doc/tapex.mdx
|
|||||||
docs/source/en/model_doc/donut.mdx
|
docs/source/en/model_doc/donut.mdx
|
||||||
docs/source/en/model_doc/encoder-decoder.mdx
|
docs/source/en/model_doc/encoder-decoder.mdx
|
||||||
src/transformers/generation/configuration_utils.py
|
src/transformers/generation/configuration_utils.py
|
||||||
|
src/transformers/generation/flax_utils.py
|
||||||
src/transformers/generation/tf_utils.py
|
src/transformers/generation/tf_utils.py
|
||||||
src/transformers/generation/utils.py
|
src/transformers/generation/utils.py
|
||||||
src/transformers/models/albert/configuration_albert.py
|
src/transformers/models/albert/configuration_albert.py
|
||||||
|
|||||||
Reference in New Issue
Block a user