From 2adc8c926aac1fd1ef97183ee4eaee78ac787aca Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 11 Mar 2021 12:56:12 -0500 Subject: [PATCH] W2v2 test require torch (#10665) * Adds a @require_torch to a test that requires it * Tokenizer too * Style --- tests/test_feature_extraction_wav2vec2.py | 3 ++- tests/test_tokenization_wav2vec2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_feature_extraction_wav2vec2.py b/tests/test_feature_extraction_wav2vec2.py index 771974a398..d55d951ee3 100644 --- a/tests/test_feature_extraction_wav2vec2.py +++ b/tests/test_feature_extraction_wav2vec2.py @@ -21,7 +21,7 @@ import unittest import numpy as np from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, Wav2Vec2Config, Wav2Vec2FeatureExtractor -from transformers.testing_utils import slow +from transformers.testing_utils import require_torch, slow from .test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin @@ -134,6 +134,7 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest _check_zero_mean_unit_variance(input_values[2]) @slow + @require_torch def test_pretrained_checkpoints_are_set_correctly(self): # this test makes sure that models that are using # group norm don't have their feature extractor return the diff --git a/tests/test_tokenization_wav2vec2.py b/tests/test_tokenization_wav2vec2.py index f7a5e4da16..002bf4b225 100644 --- a/tests/test_tokenization_wav2vec2.py +++ b/tests/test_tokenization_wav2vec2.py @@ -30,7 +30,7 @@ from transformers import ( Wav2Vec2Tokenizer, ) from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES -from transformers.testing_utils import slow +from transformers.testing_utils import require_torch, slow from .test_tokenization_common import TokenizerTesterMixin @@ -340,6 +340,7 @@ class Wav2Vec2TokenizerTest(unittest.TestCase): self.assertListEqual(processed.attention_mask.sum(-1).tolist(), [800, 1000, 1200]) @slow + @require_torch def test_pretrained_checkpoints_are_set_correctly(self): # this test makes sure that models that are using # group norm don't have their tokenizer return the