From 0e0f1f4692b9dbbab56b1adf32e0911caeecaa34 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 24 Jun 2022 19:31:30 +0200 Subject: [PATCH] Use higher value for hidden_size in Flax BigBird test (#17822) * Use higher value for hidden_size in Flax BigBird test * remove 5e-5 Co-authored-by: ydshieh --- tests/models/big_bird/test_modeling_big_bird.py | 3 +-- tests/models/big_bird/test_modeling_flax_big_bird.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index f77af8049c..ec59f8f93d 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -597,8 +597,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase): self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs) # overwrite from common in order to skip the check on `attentions` - # also use `5e-5` to avoid flaky test failure - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, # an effort was done to return `attention_probs` (yet to be verified). if name.startswith("outputs.attentions"): diff --git a/tests/models/big_bird/test_modeling_flax_big_bird.py b/tests/models/big_bird/test_modeling_flax_big_bird.py index f8659dad76..7c4c726721 100644 --- a/tests/models/big_bird/test_modeling_flax_big_bird.py +++ b/tests/models/big_bird/test_modeling_flax_big_bird.py @@ -47,7 +47,7 @@ class FlaxBigBirdModelTester(unittest.TestCase): use_token_type_ids=True, use_labels=True, vocab_size=99, - hidden_size=4, + hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=7, @@ -214,8 +214,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): self.assertEqual(jitted_output.shape, output.shape) # overwrite from common in order to skip the check on `attentions` - # also use `5e-5` to avoid flaky test failure - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None): + def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, # an effort was done to return `attention_probs` (yet to be verified). if name.startswith("outputs.attentions"):