[Flax] Adapt Flax models to new structure (#9484)
* Create modeling_flax_electra with code copied from modeling_flax_bert
* Add ElectraForMaskedLM and ElectraForPretraining
* Add modeling test for Flax Electra and fix naming and arg in Flax Electra model
* Add documentation
* Fix code style
* Create modeling_flax_electra with code copied from modeling_flax_bert
* Add ElectraForMaskedLM and ElectraForPretraining
* Add modeling test for Flax Electra and fix naming and arg in Flax Electra model
* Add documentation
* Fix code style
* Fix code quality
* Adjust tolerance in assert_almost_equals due to a very small difference between model outputs, ranging from 0.0010 to 0.0016
* Remove redundant ElectraPooler
* save intermediate
* adapt
* correct bert flax design
* adapt roberta as well
* finish roberta flax
* finish
* apply suggestions
* apply suggestions

Co-authored-by: Chris Nguyen <anhtu2687@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5c0bf39782
commit
0b98ca368f
@@ -60,6 +60,7 @@ def random_attention_mask(shape, rng=None):
|
||||
return attn_mask
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
@@ -90,7 +91,7 @@ class FlaxModelTesterMixin:
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
@@ -103,7 +104,6 @@ class FlaxModelTesterMixin:
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -121,7 +121,6 @@ class FlaxModelTesterMixin:
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@require_flax
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -144,7 +143,6 @@ class FlaxModelTesterMixin:
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
@require_flax
|
||||
def test_naming_convention(self):
|
||||
for model_class in self.all_model_classes:
|
||||
model_class_name = model_class.__name__
|
||||
|
||||
Reference in New Issue
Block a user