@@ -142,6 +142,20 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
self.assertTrue(np.allclose(mel_1, mel_2))
|
||||
self.assertEqual(dict_first, dict_second)
|
||||
|
||||
def test_feat_extract_from_pretrained_kwargs(self):
|
||||
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
|
||||
check_json_file_has_correct_format(saved_file)
|
||||
feat_extract_second = self.feature_extraction_class.from_pretrained(
|
||||
tmpdirname, feature_size=2 * self.feat_extract_dict["feature_size"]
|
||||
)
|
||||
|
||||
mel_1 = feat_extract_first.mel_filters
|
||||
mel_2 = feat_extract_second.mel_filters
|
||||
self.assertTrue(2 * mel_1.shape[1] == mel_2.shape[1])
|
||||
|
||||
def test_call(self):
|
||||
# Tests that all call wrap to encode_plus and batch_encode_plus
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
|
||||
Reference in New Issue
Block a user