From f2a2616b7462c2f213dbc93332ddf81cae2ef874 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 6 Mar 2023 18:07:31 +0100 Subject: [PATCH] Update expected values for `test_xglm_sample` (#21975) update expected values for xglm Co-authored-by: ydshieh --- tests/models/xglm/test_modeling_xglm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/models/xglm/test_modeling_xglm.py b/tests/models/xglm/test_modeling_xglm.py index 010fef113a..9fcc25b6d2 100644 --- a/tests/models/xglm/test_modeling_xglm.py +++ b/tests/models/xglm/test_modeling_xglm.py @@ -428,8 +428,14 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase): output_ids = model.generate(input_ids, do_sample=True, num_beams=1) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) - EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy" - self.assertEqual(output_str, EXPECTED_OUTPUT_STR) + EXPECTED_OUTPUT_STRS = [ + # TODO: remove this once we move to torch 2.0 + # torch 1.13.1 + cu116 + "Today is a nice day and the sun is shining. A nice day with warm rainy", + # torch 2.0 + cu117 + "Today is a nice day and the water is still cold. We just stopped off for some fresh", + ] + self.assertIn(output_str, EXPECTED_OUTPUT_STRS) @slow def test_xglm_sample_max_time(self):