From 5e428b71b4df5592fc9e9cb8f6840e1a1e7fca77 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Jun 2022 16:54:14 +0200 Subject: [PATCH] [BigBirdFlaxTests] Make tests slow (#17658) * [BigBirdFlaxTests] Make tests slow * up * correct black with new version --- .../big_bird/test_modeling_flax_big_bird.py | 34 +++++++++++++++---- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/models/big_bird/test_modeling_flax_big_bird.py b/tests/models/big_bird/test_modeling_flax_big_bird.py index 5c5452441e..3a07996e7a 100644 --- a/tests/models/big_bird/test_modeling_flax_big_bird.py +++ b/tests/models/big_bird/test_modeling_flax_big_bird.py @@ -40,17 +40,17 @@ class FlaxBigBirdModelTester(unittest.TestCase): def __init__( self, parent, - batch_size=13, + batch_size=2, seq_length=56, is_training=True, use_attention_mask=True, use_token_type_ids=True, use_labels=True, vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, + hidden_size=4, + num_hidden_layers=2, + num_attention_heads=2, + intermediate_size=7, hidden_act="gelu_new", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, @@ -62,7 +62,7 @@ class FlaxBigBirdModelTester(unittest.TestCase): attention_type="block_sparse", use_bias=True, rescale_embeddings=False, - block_size=4, + block_size=2, num_random_blocks=3, ): self.parent = parent @@ -156,10 +156,30 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxBigBirdModelTester(self) + @slow + # copied from `test_modeling_flax_common` because it takes much longer than other models + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @slow + # copied from `test_modeling_flax_common` because it takes much longer than other models + def test_from_pretrained_with_no_automatic_init(self): + super().test_from_pretrained_with_no_automatic_init() + + @slow + # copied from `test_modeling_flax_common` because it takes much longer than other models + def test_no_automatic_init(self): + super().test_no_automatic_init() + + @slow + # copied from `test_modeling_flax_common` because it takes much longer than other models + def test_hidden_states_output(self): + super().test_hidden_states_output() + @slow def test_model_from_pretrained(self): for model_class_name in self.all_model_classes: - model = model_class_name.from_pretrained("google/bigbird-roberta-base", from_pt=True) + model = model_class_name.from_pretrained("google/bigbird-roberta-base") outputs = model(np.ones((1, 1))) self.assertIsNotNone(outputs)