From e3669375875c4fd6c8c28b193befde9a9d6e78ce Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:55:29 +0100 Subject: [PATCH] Fix 2 tests in `FillMaskPipelineTests` (#27889) * fix * fix * fix * fix * fix --------- Co-authored-by: ydshieh --- tests/pipelines/test_pipelines_fill_mask.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/test_pipelines_fill_mask.py b/tests/pipelines/test_pipelines_fill_mask.py index c85797fbb6..571b320d61 100644 --- a/tests/pipelines/test_pipelines_fill_mask.py +++ b/tests/pipelines/test_pipelines_fill_mask.py @@ -216,15 +216,24 @@ class FillMaskPipelineTests(unittest.TestCase): ], ) + dummy_str = "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100 outputs = unmasker( - "My name is " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit," * 100, + "My name is " + dummy_str, tokenizer_kwargs={"truncation": True}, ) + simplified = nested_simplify(outputs, decimals=4) self.assertEqual( - nested_simplify(outputs, decimals=6), + [{"sequence": x["sequence"][:100]} for x in simplified], [ - {"sequence": "My name is grouped", "score": 2.2e-05, "token": 38015, "token_str": " grouped"}, - {"sequence": "My name is accuser", "score": 2.1e-05, "token": 25506, "token_str": " accuser"}, + {"sequence": f"My name is,{dummy_str}"[:100]}, + {"sequence": f"My name is:,{dummy_str}"[:100]}, + ], + ) + self.assertEqual( + [{k: x[k] for k in x if k != "sequence"} for x in simplified], + [ + {"score": 0.2819, "token": 6, "token_str": ","}, + {"score": 0.0954, "token": 46686, "token_str": ":,"}, ], )