Copied from for test files (#26713)
* copied statement for test files --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import ClapFeatureExtractor
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
@@ -110,10 +111,10 @@ class ClapFeatureExtractionTester(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest with Whisper->Clap
|
||||
class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
|
||||
feature_extraction_class = ClapFeatureExtractor
|
||||
|
||||
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.setUp with Whisper->Clap
|
||||
def setUp(self):
|
||||
self.feat_extract_tester = ClapFeatureExtractionTester(self)
|
||||
|
||||
@@ -147,6 +148,7 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
|
||||
def test_double_precision_pad(self):
|
||||
import torch
|
||||
|
||||
@@ -160,9 +162,8 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
|
||||
self.assertTrue(pt_processed.input_features.dtype == torch.float32)
|
||||
|
||||
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest._load_datasamples
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
Reference in New Issue
Block a user