From b4eef63a1de97b9bbd8d54b83ede16e34afe3529 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Jun 2022 18:48:58 +0200 Subject: [PATCH] [Data2Vec] Speed up test (#17660) --- tests/models/data2vec/test_modeling_data2vec_audio.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py index 24e2cd918d..e3fb96097d 100644 --- a/tests/models/data2vec/test_modeling_data2vec_audio.py +++ b/tests/models/data2vec/test_modeling_data2vec_audio.py @@ -535,7 +535,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase): def test_mask_feature_prob_ctc(self): model = Data2VecAudioForCTC.from_pretrained( - "facebook/data2vec-audio-base-960h", mask_feature_prob=0.2, mask_feature_length=2 + "hf-internal-testing/tiny-random-data2vec-seq-class", mask_feature_prob=0.2, mask_feature_length=2 ) model.to(torch_device).train() processor = Wav2Vec2Processor.from_pretrained( @@ -554,7 +554,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase): attention_mask=batch["attention_mask"].to(torch_device), ).logits - self.assertEqual(logits.shape, (4, 299, 32)) + self.assertEqual(logits.shape, (4, 1498, 32)) def test_mask_time_prob_ctc(self): model = Data2VecAudioForCTC.from_pretrained(