[Tests] Update audio classification tests to support torch 1.10 (#14318)
This commit is contained in:
@@ -82,7 +82,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skip("Skip tests while investigating difference between PyTorch 1.9 and 1.10")
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_small_model_pt(self):
|
def test_small_model_pt(self):
|
||||||
model = "anton-l/wav2vec2-random-tiny-classifier"
|
model = "anton-l/wav2vec2-random-tiny-classifier"
|
||||||
@@ -94,10 +93,10 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(output, decimals=4),
|
nested_simplify(output, decimals=4),
|
||||||
[
|
[
|
||||||
{"score": 0.0843, "label": "on"},
|
{"score": 0.0842, "label": "no"},
|
||||||
{"score": 0.0840, "label": "left"},
|
{"score": 0.0838, "label": "up"},
|
||||||
{"score": 0.0837, "label": "off"},
|
{"score": 0.0837, "label": "go"},
|
||||||
{"score": 0.0835, "label": "yes"},
|
{"score": 0.0834, "label": "right"},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -117,7 +116,7 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
nested_simplify(output, decimals=4),
|
nested_simplify(output, decimals=4),
|
||||||
[
|
[
|
||||||
{"score": 0.981, "label": "go"},
|
{"score": 0.9809, "label": "go"},
|
||||||
{"score": 0.0073, "label": "up"},
|
{"score": 0.0073, "label": "up"},
|
||||||
{"score": 0.0064, "label": "_unknown_"},
|
{"score": 0.0064, "label": "_unknown_"},
|
||||||
{"score": 0.0015, "label": "down"},
|
{"score": 0.0015, "label": "down"},
|
||||||
|
|||||||
Reference in New Issue
Block a user