Update expected values for test_xglm_sample (#21975)
update expected values for xglm Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -428,8 +428,14 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
|||||||
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
|
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
|
||||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
EXPECTED_OUTPUT_STR = "Today is a nice day and the sun is shining. A nice day with warm rainy"
|
EXPECTED_OUTPUT_STRS = [
|
||||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
# TODO: remove this once we move to torch 2.0
|
||||||
|
# torch 1.13.1 + cu116
|
||||||
|
"Today is a nice day and the sun is shining. A nice day with warm rainy",
|
||||||
|
# torch 2.0 + cu117
|
||||||
|
"Today is a nice day and the water is still cold. We just stopped off for some fresh",
|
||||||
|
]
|
||||||
|
self.assertIn(output_str, EXPECTED_OUTPUT_STRS)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_xglm_sample_max_time(self):
|
def test_xglm_sample_max_time(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user