Expectation fixes and added AMD expectations (#38729)
This commit is contained in:
@@ -17,7 +17,9 @@ import unittest
|
||||
|
||||
from transformers import XGLMConfig, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
is_torch_greater_or_equal,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
|
||||
@@ -422,13 +424,21 @@ class XGLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
output_ids = model.generate(input_ids, do_sample=True, num_beams=1)
|
||||
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
||||
|
||||
EXPECTED_OUTPUT_STRS = [
|
||||
# torch 2.6
|
||||
"Today is a nice day and the water is still cold. We just stopped off for some fresh coffee. This place looks like a",
|
||||
# torch 2.7
|
||||
"Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today.",
|
||||
]
|
||||
self.assertIn(output_str, EXPECTED_OUTPUT_STRS)
|
||||
if is_torch_greater_or_equal("2.7.0"):
|
||||
cuda_expectation = (
|
||||
"Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today."
|
||||
)
|
||||
else:
|
||||
cuda_expectation = "Today is a nice day and the water is still cold. We just stopped off for some fresh coffee. This place looks like a"
|
||||
|
||||
expected_output_strings = Expectations(
|
||||
{
|
||||
("rocm", (9, 5)): "Today is a nice day and the sun is shining. A nice day with warm rainy and windy weather today.",
|
||||
("cuda", None): cuda_expectation,
|
||||
}
|
||||
) # fmt: skip
|
||||
EXPECTED_OUTPUT_STR = expected_output_strings.get_expectation()
|
||||
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
|
||||
Reference in New Issue
Block a user