Use 5e-5 For BigBird PT/Flax equivalence tests (#17780)

* rename to check_pt_flax_outputs

* update check_pt_flax_outputs

* use 5e-5 for BigBird PT/Flax test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-06-21 17:55:26 +02:00
committed by GitHub
parent 6a5272b205
commit f47afefb21
4 changed files with 153 additions and 48 deletions

View File

@@ -597,13 +597,14 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
# overwrite from common in order to skip the check on `attentions`
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
# also use `5e-5` to avoid flaky test failure
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in the PyTorch version,
# an effort was made to return `attention_probs` (yet to be verified).
if type(names) == str and names.startswith("attentions"):
if name.startswith("outputs.attentions"):
return
else:
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@require_torch

View File

@@ -214,10 +214,11 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
self.assertEqual(jitted_output.shape, output.shape)
# overwrite from common in order to skip the check on `attentions`
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
# also use `5e-5` to avoid flaky test failure
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in the PyTorch version,
# an effort was made to return `attention_probs` (yet to be verified).
if type(names) == str and names.startswith("attentions"):
if name.startswith("outputs.attentions"):
return
else:
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)