From 76d24f1a83f0193fa606fc116514f9f86356bfed Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 17 Apr 2023 12:41:55 +0200 Subject: [PATCH] Fix `test_word_time_stamp_integration` for `Wav2Vec2ProcessorWithLMTest` (#22800) * fix --------- Co-authored-by: ydshieh --- .../test_processor_wav2vec2_with_lm.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py index a98ea55d0b..bd1582ceb1 100644 --- a/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py +++ b/tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py @@ -23,7 +23,6 @@ from pathlib import Path import datasets import numpy as np from datasets import load_dataset -from packaging import version from parameterized import parameterized from transformers import AutoProcessor @@ -461,7 +460,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): for d in output["word_offsets"] ] - EXPECTED_TEXT = "WHY DOES A MILE SANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL" + EXPECTED_TEXT = "WHY DOES MILISANDRA LOOK LIKE SHE WANTS TO CONSUME JOHN SNOW ON THE RIVER AT THE WALL" # output words self.assertEqual(" ".join(self.get_from_offsets(word_time_stamps, "word")), EXPECTED_TEXT) @@ -472,14 +471,8 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase): end_times = torch.tensor(self.get_from_offsets(word_time_stamps, "end_time")) # fmt: off - expected_start_tensor = torch.tensor([1.42, 1.64, 2.12, 2.26, 2.54, 3.0, 3.24, 3.6, 3.8, 4.1, 4.26, 4.94, 5.28, 5.66, 5.78, 5.94, 6.32, 6.54, 6.66]) - - # TODO(Patrick): This if-else version statement should be removed once - # https://github.com/huggingface/datasets/issues/4889 is resolved - if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12.0"): - expected_end_tensor = torch.tensor([1.54, 1.88, 2.14, 2.46, 2.9, 3.16, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94]) - else: - expected_end_tensor = torch.tensor([1.54, 1.88, 2.14, 2.46, 2.9, 3.18, 3.54, 3.72, 4.02, 4.18, 4.76, 5.16, 5.56, 5.7, 5.86, 6.2, 6.38, 6.62, 6.94]) + expected_start_tensor = torch.tensor([1.4199, 1.6599, 2.2599, 3.0, 3.24, 3.5999, 3.7999, 4.0999, 4.26, 4.94, 5.28, 5.6599, 5.78, 5.94, 6.32, 6.5399, 6.6599]) + expected_end_tensor = torch.tensor([1.5399, 1.8999, 2.9, 3.16, 3.5399, 3.72, 4.0199, 4.1799, 4.76, 5.1599, 5.5599, 5.6999, 5.86, 6.1999, 6.38, 6.6199, 6.94]) # fmt: on self.assertTrue(torch.allclose(start_times, expected_start_tensor, atol=0.01))