Skip pt/flax equivalence tests in pytorch bigbird test file (#23040)

skip

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-04-28 17:00:13 +02:00
committed by GitHub
parent 4d0ea3d269
commit 0bf34b1c9f

View File

@@ -20,7 +20,7 @@ import unittest
from transformers import BigBirdConfig, is_torch_available
from transformers.models.auto import get_values
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -618,6 +618,20 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
else:
super().check_pt_flax_outputs(fx_outputs, pt_outputs, model_class, tol, name, attributes)
@is_pt_flax_cross_test
@unittest.skip(
reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode"
)
def test_equivalence_flax_to_pt(self):
pass
@is_pt_flax_cross_test
@unittest.skip(
reason="Current Pytorch implementation has bug with random attention -> it always uses it not matter if we are in eval/train mode"
)
def test_equivalence_pt_to_flax(self):
pass
@require_torch
@slow