Update expected values for test_xglm_sample (#21975)

update expected values for xglm

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-03-06 18:07:31 +01:00
committed by GitHub
parent 5d8efc79db
commit f2a2616b74

View File

@@ -428,8 +428,14 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
output_ids = model.generate(input_ids, do_sample=True, num_beams=1) output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy" EXPECTED_OUTPUT_STRS = [
self.assertEqual(output_str, EXPECTED_OUTPUT_STR) # TODO: remove this once we move to torch 2.0
# torch 1.13.1 + cu116
"Today is a nice day and the sun is shining. A nice day with warm rainy",
# torch 2.0 + cu117
"Today is a nice day and the water is still cold. We just stopped off for some fresh",
]
self.assertIn(output_str, EXPECTED_OUTPUT_STRS)
@slow @slow
def test_xglm_sample_max_time(self): def test_xglm_sample_max_time(self):