From 3499c49c1791689ccc336eb2bee847df1d9227b2 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Wed, 15 Feb 2023 16:00:50 +0000 Subject: [PATCH] Skipping more high mem tests - Wav2Vec2 Hubert (#21647) Skipping more tests --- tests/models/hubert/test_modeling_tf_hubert.py | 16 ++++++---------- .../wav2vec2/test_modeling_tf_wav2vec2.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py index 05fcab290b..15cf801ea6 100644 --- a/tests/models/hubert/test_modeling_tf_hubert.py +++ b/tests/models/hubert/test_modeling_tf_hubert.py @@ -321,19 +321,15 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase): model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960") self.assertIsNotNone(model) - # We override here as passing a full batch of 13 samples results in OOM errors for CTC + @unittest.skip(reason="Fix me! Hubert hits OOM errors when loss is computed on full batch") def test_dataset_conversion(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_dataset_conversion() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass - # We override here as passing a full batch of 13 samples results in OOM errors for CTC + @unittest.skip(reason="Fix me! Hubert hits OOM errors when loss is computed on full batch") def test_keras_fit(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_keras_fit() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass @require_tf diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 8d8b84ceda..704cffb834 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -512,20 +512,15 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h") self.assertIsNotNone(model) - # We override here as passing a full batch of 13 samples results in OOM errors for CTC - @unittest.skip("Fix me!") + @unittest.skip(reason="Fix me! Wav2Vec2 hits OOM errors when loss is computed on full batch") def test_dataset_conversion(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_dataset_conversion() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass - # We override here as passing a full batch of 13 samples results in OOM errors for CTC + @unittest.skip(reason="Fix me! Wav2Vec2 hits OOM errors when loss is computed on full batch") def test_keras_fit(self): - default_batch_size = self.model_tester.batch_size - self.model_tester.batch_size = 2 - super().test_keras_fit() - self.model_tester.batch_size = default_batch_size + # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC + pass @require_tf