From e30078b544f0d15eb398c6bb9dc7089f96e924e2 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Mon, 8 Nov 2021 14:15:56 +0300 Subject: [PATCH] [Tests] Update audio classification tests to support torch 1.10 (#14318) --- tests/test_pipelines_audio_classification.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_pipelines_audio_classification.py b/tests/test_pipelines_audio_classification.py index 1b0ad5d2cb..f01825dd99 100644 --- a/tests/test_pipelines_audio_classification.py +++ b/tests/test_pipelines_audio_classification.py @@ -82,7 +82,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ], ) - @unittest.skip("Skip tests while investigating difference between PyTorch 1.9 and 1.10") @require_torch def test_small_model_pt(self): model = "anton-l/wav2vec2-random-tiny-classifier" @@ -94,10 +93,10 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest self.assertEqual( nested_simplify(output, decimals=4), [ - {"score": 0.0843, "label": "on"}, - {"score": 0.0840, "label": "left"}, - {"score": 0.0837, "label": "off"}, - {"score": 0.0835, "label": "yes"}, + {"score": 0.0842, "label": "no"}, + {"score": 0.0838, "label": "up"}, + {"score": 0.0837, "label": "go"}, + {"score": 0.0834, "label": "right"}, ], ) @@ -117,7 +116,7 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest self.assertEqual( nested_simplify(output, decimals=4), [ - {"score": 0.981, "label": "go"}, + {"score": 0.9809, "label": "go"}, {"score": 0.0073, "label": "up"}, {"score": 0.0064, "label": "_unknown_"}, {"score": 0.0015, "label": "down"},