[Flax] Adapt Flax models to new structure (#9484)

* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Create modeling_flax_eletra with code copied from modeling_flax_bert

* Add ElectraForMaskedLM and ElectraForPretraining

* Add modeling test for Flax electra and fix naming and arg in Flax Electra model

* Add documentation

* Fix code style

* Fix code quality

* Adjust tol in assert_almost_equal due to very small difference between model output, ranging 0.0010 - 0.0016

* Remove redundant ElectraPooler

* save intermediate

* adapt

* correct bert flax design

* adapt roberta as well

* finish roberta flax

* finish

* apply suggestions

* apply suggestions

Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
This commit is contained in:
Patrick von Platen
2021-03-18 09:44:17 +03:00
committed by GitHub
parent 5c0bf39782
commit 0b98ca368f
3 changed files with 316 additions and 500 deletions

View File

@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
return attn_mask
@require_flax
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
@@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
@require_flax
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@require_flax
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@require_flax
def test_naming_convention(self):
for model_class in self.all_model_classes:
model_class_name = model_class.__name__