From e39172ecab1d6b57885853d24e0fc53d3d6956b3 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 13 Jun 2025 15:19:41 +0200 Subject: [PATCH] Fix `llava_next` tests (#38813) * fix * fix --------- Co-authored-by: ydshieh --- tests/models/llava_next/test_modeling_llava_next.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index 93142f1da6..cb573913e4 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -392,7 +392,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): load_in_4bit=True, ) - inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt") + inputs = self.processor(images=self.image, text=self.prompt, return_tensors="pt").to(torch_device) # verify inputs against original implementation filepath = hf_hub_download( @@ -415,11 +415,13 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): ) check_torch_load_is_safe() original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True) - assert torch.allclose(original_pixel_values, inputs.pixel_values.half()) + assert torch.allclose( + original_pixel_values, inputs.pixel_values.to(device="cpu", dtype=original_pixel_values.dtype) + ) # verify generation output = model.generate(**inputs, max_new_tokens=100) - EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L' # fmt: skip + EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart is showing the performance of various models or systems across different metrics or datasets.\n\nThe chart is divided into several sections, each representing a different model or dataset. The axes represent different metrics or datasets, such as "MMM-Vet," "MMM-Bench," "L' self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), @@ -511,7 +513,7 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): # verify generation output = model.generate(**inputs, max_new_tokens=50) - EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the time of day seems to be either dawn or dusk, given the soft' # fmt: skip + EXPECTED_DECODED_TEXT = "[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given" self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT,