Use 5e-5 For BigBird PT/Flax equivalence tests (#17780)
* rename to check_pt_flax_outputs
* update check_pt_flax_outputs
* use 5e-5 for BigBird PT/Flax test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -597,13 +597,14 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
|
||||
|
||||
# overwrite from common in order to skip the check on `attentions`
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||
# also use `5e-5` to avoid flaky test failure
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||
# an effort was made to return `attention_probs` (yet to be verified).
|
||||
if type(names) == str and names.startswith("attentions"):
|
||||
if name.startswith("outputs.attentions"):
|
||||
return
|
||||
else:
|
||||
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -214,10 +214,11 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(jitted_output.shape, output.shape)
|
||||
|
||||
# overwrite from common in order to skip the check on `attentions`
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||
# also use `5e-5` to avoid flaky test failure
|
||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
||||
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||
# an effort was made to return `attention_probs` (yet to be verified).
|
||||
if type(names) == str and names.startswith("attentions"):
|
||||
if name.startswith("outputs.attentions"):
|
||||
return
|
||||
else:
|
||||
super().check_outputs(fx_outputs, pt_outputs, model_class, names)
|
||||
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
Reference in New Issue
Block a user