From fc28c006a612d643505d4a00b07c59023382069c Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 15 Feb 2023 14:17:26 +0000 Subject: [PATCH] Skip wav2vec2 hubert high mem tests (#21643) * Skip high memory tests * Skip high memory tests * Remove unused import --- tests/models/hubert/test_modeling_tf_hubert.py | 16 ++++++---------- .../wav2vec2/test_modeling_tf_wav2vec2.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py index b20119c648..05fcab290b 100644 --- a/tests/models/hubert/test_modeling_tf_hubert.py +++ b/tests/models/hubert/test_modeling_tf_hubert.py @@ -450,19 +450,15 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft") self.assertIsNotNone(model) - # We override here as passing a full batch of 13 samples results in OOM errors for CTC + @unittest.skip(reason="Fix me! Hubert hits OOM errors when loss is computed on full batch") def test_dataset_conversion(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_dataset_conversion() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass - # We override here as passing a full batch of 13 samples results in OOM errors for CTC + @unittest.skip(reason="Fix me! Hubert hits OOM errors when loss is computed on full batch") def test_keras_fit(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_keras_fit() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass @require_tf diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index bcfecdb257..8d8b84ceda 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -385,20 +385,15 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsNotNone(model) - # We override here as passing a full batch of 13 samples results in OOM errors for CTC - @unittest.skip("Fix me!") + @unittest.skip(reason="Fix me! Wav2Vec2 hits OOM errors when loss is computed on full batch") def test_dataset_conversion(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_dataset_conversion() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass - # We override here as passing a full batch of 13 samples results in OOM errors for CTC + @unittest.skip(reason="Fix me! Wav2Vec2 hits OOM errors when loss is computed on full batch") def test_keras_fit(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_keras_fit() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass @require_tf