Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)
This commit is contained in:
@@ -15,9 +15,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -160,6 +161,24 @@ class FlaxGenerationMixin:
|
|||||||
"""
|
"""
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
|
||||||
|
"""Validates model kwargs for generation. Generate argument typos will also be caught here."""
|
||||||
|
unused_model_args = []
|
||||||
|
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
|
||||||
|
# `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If
|
||||||
|
# `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;)
|
||||||
|
if "kwargs" in model_args:
|
||||||
|
model_args |= set(inspect.signature(self.__call__).parameters)
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if value is not None and key not in model_args:
|
||||||
|
unused_model_args.append(key)
|
||||||
|
|
||||||
|
if unused_model_args:
|
||||||
|
raise ValueError(
|
||||||
|
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
||||||
|
" generate arguments will also show up in this list)"
|
||||||
|
)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
input_ids: jnp.ndarray,
|
input_ids: jnp.ndarray,
|
||||||
@@ -262,6 +281,9 @@ class FlaxGenerationMixin:
|
|||||||
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
|
||||||
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||||
```"""
|
```"""
|
||||||
|
# Validate model kwargs
|
||||||
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
# set init values
|
# set init values
|
||||||
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
||||||
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import random
|
import random
|
||||||
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax import jit
|
from jax import jit
|
||||||
|
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||||
|
|
||||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
|
|||||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||||
|
|
||||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxGenerationIntegrationTests(unittest.TestCase):
|
||||||
|
def test_validate_generation_inputs(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
|
||||||
|
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
|
encoder_input_str = "Hello world"
|
||||||
|
input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids
|
||||||
|
|
||||||
|
# typos are quickly detected (the correct argument is `do_sample`)
|
||||||
|
with self.assertRaisesRegex(ValueError, "do_samples"):
|
||||||
|
model.generate(input_ids, do_samples=True)
|
||||||
|
|
||||||
|
# arbitrary arguments that will not be used anywhere are also not accepted
|
||||||
|
with self.assertRaisesRegex(ValueError, "foo"):
|
||||||
|
fake_model_kwargs = {"foo": "bar"}
|
||||||
|
model.generate(input_ids, **fake_model_kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user