[Wav2vec2 + LM Test] Improve wav2vec2 with lm tests and make torch version dependent for now (#18749)
* add first generation tutorial * remove generation * make version dependent expected values * Apply suggestions from code review * Update tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py * fix typo
This commit is contained in:
committed by
GitHub
parent
8869bf41fe
commit
62ceb4d661
@@ -23,6 +23,7 @@ from pathlib import Path
|
|||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||||
@@ -435,21 +436,19 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
|
self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), output.text)
|
||||||
|
|
||||||
# output times
|
# output times
|
||||||
start_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "start_time")]
|
start_times = torch.tensor(self.get_from_offsets(word_time_stamps, "start_time"))
|
||||||
end_times = [round(x, 2) for x in self.get_from_offsets(word_time_stamps, "end_time")]
|
end_times = torch.tensor(self.get_from_offsets(word_time_stamps, "end_time"))
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
self.assertListEqual(
|
expected_start_tensor = torch.tensor([1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66])
|
||||||
start_times,
|
|
||||||
[
|
|
||||||
1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertListEqual(
|
# TODO(Patrick): This if-else version statement should be removed once
|
||||||
end_times,
|
# https://github.com/huggingface/datasets/issues/4889 is resolved
|
||||||
[
|
if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12.0"):
|
||||||
1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94,
|
expected_end_tensor = torch.tensor([1.54, 1.88, 2.14, 2.46, 2.9, 3.16, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94])
|
||||||
],
|
else:
|
||||||
)
|
expected_end_tensor = torch.tensor([1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94])
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(start_times, expected_start_tensor, atol=0.01))
|
||||||
|
self.assertTrue(torch.allclose(end_times, expected_end_tensor, atol=0.01))
|
||||||
|
|||||||
Reference in New Issue
Block a user