Generate: validate model_kwargs on FLAX (and catch typos in generate arguments) (#18653)
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -26,6 +27,7 @@ if is_flax_available():
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
@@ -273,3 +275,22 @@ class FlaxGenerationTesterMixin:
|
||||
jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
|
||||
|
||||
self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxGenerationIntegrationTests(unittest.TestCase):
|
||||
def test_validate_generation_inputs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert")
|
||||
model = FlaxAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
encoder_input_str = "Hello world"
|
||||
input_ids = tokenizer(encoder_input_str, return_tensors="np").input_ids
|
||||
|
||||
# typos are quickly detected (the correct argument is `do_sample`)
|
||||
with self.assertRaisesRegex(ValueError, "do_samples"):
|
||||
model.generate(input_ids, do_samples=True)
|
||||
|
||||
# arbitrary arguments that will not be used anywhere are also not accepted
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
fake_model_kwargs = {"foo": "bar"}
|
||||
model.generate(input_ids, **fake_model_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user