[Data2Vec] Speed up test (#17660)
This commit is contained in:
committed by
GitHub
parent
5e428b71b4
commit
b4eef63a1d
@@ -535,7 +535,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_mask_feature_prob_ctc(self):
|
def test_mask_feature_prob_ctc(self):
|
||||||
model = Data2VecAudioForCTC.from_pretrained(
|
model = Data2VecAudioForCTC.from_pretrained(
|
||||||
"facebook/data2vec-audio-base-960h", mask_feature_prob=0.2, mask_feature_length=2
|
"hf-internal-testing/tiny-random-data2vec-seq-class", mask_feature_prob=0.2, mask_feature_length=2
|
||||||
)
|
)
|
||||||
model.to(torch_device).train()
|
model.to(torch_device).train()
|
||||||
processor = Wav2Vec2Processor.from_pretrained(
|
processor = Wav2Vec2Processor.from_pretrained(
|
||||||
@@ -554,7 +554,7 @@ class Data2VecAudioModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
attention_mask=batch["attention_mask"].to(torch_device),
|
attention_mask=batch["attention_mask"].to(torch_device),
|
||||||
).logits
|
).logits
|
||||||
|
|
||||||
self.assertEqual(logits.shape, (4, 299, 32))
|
self.assertEqual(logits.shape, (4, 1498, 32))
|
||||||
|
|
||||||
def test_mask_time_prob_ctc(self):
|
def test_mask_time_prob_ctc(self):
|
||||||
model = Data2VecAudioForCTC.from_pretrained(
|
model = Data2VecAudioForCTC.from_pretrained(
|
||||||
|
|||||||
Reference in New Issue
Block a user