[tests] make cuda-only tests device-agnostic (#35607)

* intial commit

* remove unrelated files

* further remove

* Update test_trainer.py

* fix style
This commit is contained in:
Fanli Lin
2025-01-13 21:48:39 +08:00
committed by GitHub
parent e6f9b03464
commit 2fa876d2d8
18 changed files with 57 additions and 47 deletions

View File

@@ -27,6 +27,7 @@ from parameterized import parameterized
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
from transformers.testing_utils import (
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_torch_gpu,
require_torch_multi_accelerator,
@@ -1565,7 +1566,7 @@ class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@slow
@require_torch_gpu
@require_torch_accelerator
def test_model_from_pretrained(self):
model_name = "Salesforce/blip2-itm-vit-g"
model = Blip2TextModelWithProjection.from_pretrained(model_name)
@@ -2191,7 +2192,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
self.assertTrue(generated_text_expanded == generated_text)
@require_torch_gpu
@require_torch_accelerator
def test_inference_itm(self):
model_name = "Salesforce/blip2-itm-vit-g"
processor = Blip2Processor.from_pretrained(model_name)
@@ -2210,7 +2211,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
@require_torch_gpu
@require_torch_accelerator
@require_torch_fp16
def test_inference_itm_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
@@ -2232,7 +2233,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
@require_torch_gpu
@require_torch_accelerator
@require_torch_fp16
def test_inference_vision_with_projection_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"
@@ -2256,7 +2257,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
]
self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))
@require_torch_gpu
@require_torch_accelerator
@require_torch_fp16
def test_inference_text_with_projection_fp16(self):
model_name = "Salesforce/blip2-itm-vit-g"