[Flax] Allow dataclasses to be jitted (#11886)
* fix_torch_device_generate_test * remove @ * change dataclasses to flax ones * fix typo * fix jitted tests * fix bert & electra
This commit is contained in:
committed by
GitHub
parent
e6126e1932
commit
d5a72b6e19
@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask=None, **kwargs):
|
||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
|
||||
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**prepared_inputs_dict)
|
||||
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
|
||||
return model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# jitted function cannot return OrderedDict
|
||||
with self.assertRaises(TypeError):
|
||||
model_jitted_return_dict(**prepared_inputs_dict)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user