From 0bf34b1c9f1311d5bed914b0e631db4ef0c65089 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 28 Apr 2023 17:00:13 +0200 Subject: [PATCH] Skip pt/flax equivalence tests in pytorch `bigbird` test file (#23040) skip Co-authored-by: ydshieh --- tests/models/big_bird/test_modeling_big_bird.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/models/big_bird/test_modeling_big_bird.py b/tests/models/big_bird/test_modeling_big_bird.py index b552473ec4..69455935e5 100644 --- a/tests/models/big_bird/test_modeling_big_bird.py +++ b/tests/models/big_bird/test_modeling_big_bird.py @@ -20,7 +20,7 @@ import unittest from transformers import BigBirdConfig, is_torch_available from transformers.models.auto import get_values from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -618,6 +618,20 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) else: super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes) + @is_pt_flax_cross_test + @unittest.skip( + reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode" + ) + def test_equivalence_flax_to_pt(self): + pass + + @is_pt_flax_cross_test + @unittest.skip( + reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode" + ) + def test_equivalence_pt_to_flax(self): + pass + @require_torch @slow