Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)

This commit is contained in:
Joao Gante
2022-08-18 10:56:21 +01:00
committed by GitHub
parent 0ea53822f8
commit a541d97477
2 changed files with 44 additions and 1 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import random
import unittest
import numpy as np
@@ -26,6 +27,7 @@ if is_flax_available():
import jax.numpy as jnp
from jax import jit
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
@require_flax
class FlaxGenerationIntegrationTests(unittest.TestCase):
    """Integration tests for keyword-argument validation in Flax `generate`."""

    def test_validate_generation_inputs(self):
        """Check that `generate` raises a `ValueError` mentioning each unrecognised keyword."""
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
        model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")

        prompt = "Hello world"
        encoding = tokenizer(prompt, return_tensors="np")
        input_ids = encoding.input_ids

        # Each entry pairs a set of invalid keyword arguments with the text
        # that the resulting error message is expected to contain.
        rejected_cases = (
            # typos are quickly detected (the correct argument is `do_sample`)
            ({"do_samples": True}, "do_samples"),
            # arbitrary arguments that will not be used anywhere are also not accepted
            ({"foo": "bar"}, "foo"),
        )
        for invalid_kwargs, expected_pattern in rejected_cases:
            with self.assertRaisesRegex(ValueError, expected_pattern):
                model.generate(input_ids, **invalid_kwargs)