Use higher value for hidden_size in Flax BigBird test (#17822)
* Use higher value for hidden_size in Flax BigBird test * remove 5e-5 Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -597,8 +597,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
|
self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)
|
||||||
|
|
||||||
# overwrite from common in order to skip the check on `attentions`
|
# overwrite from common in order to skip the check on `attentions`
|
||||||
# also use `5e-5` to avoid flaky test failure
|
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
|
||||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
|
||||||
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||||
# an effort was done to return `attention_probs` (yet to be verified).
|
# an effort was done to return `attention_probs` (yet to be verified).
|
||||||
if name.startswith("outputs.attentions"):
|
if name.startswith("outputs.attentions"):
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class FlaxBigBirdModelTester(unittest.TestCase):
|
|||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
use_labels=True,
|
use_labels=True,
|
||||||
vocab_size=99,
|
vocab_size=99,
|
||||||
hidden_size=4,
|
hidden_size=32,
|
||||||
num_hidden_layers=2,
|
num_hidden_layers=2,
|
||||||
num_attention_heads=2,
|
num_attention_heads=2,
|
||||||
intermediate_size=7,
|
intermediate_size=7,
|
||||||
@@ -214,8 +214,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertEqual(jitted_output.shape, output.shape)
|
self.assertEqual(jitted_output.shape, output.shape)
|
||||||
|
|
||||||
# overwrite from common in order to skip the check on `attentions`
|
# overwrite from common in order to skip the check on `attentions`
|
||||||
# also use `5e-5` to avoid flaky test failure
|
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
|
||||||
def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=5e-5, name="outputs", attributes=None):
|
|
||||||
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
# `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
|
||||||
# an effort was done to return `attention_probs` (yet to be verified).
|
# an effort was done to return `attention_probs` (yet to be verified).
|
||||||
if name.startswith("outputs.attentions"):
|
if name.startswith("outputs.attentions"):
|
||||||
|
|||||||
Reference in New Issue
Block a user