[Flax] Allow dataclasses to be jitted (#11886)

* fix_torch_device_generate_test

* remove @

* change dataclasses to flax ones

* fix typo

* fix jitted tests

* fix bert & electra
This commit is contained in:
Patrick von Platen
2021-05-26 15:01:13 +01:00
committed by GitHub
parent e6126e1932
commit d5a72b6e19
4 changed files with 17 additions and 30 deletions

View File

@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
@jax.jit
def model_jitted(input_ids, attention_mask=None, **kwargs):
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@jax.jit
def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
return model(
input_ids=input_ids,
attention_mask=attention_mask,
**kwargs,
)
# jitted function cannot return OrderedDict
with self.assertRaises(TypeError):
model_jitted_return_dict(**prepared_inputs_dict)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()