[tests] make cuda-only tests device-agnostic (#35607)
* intial commit * remove unrelated files * further remove * Update test_trainer.py * fix style
This commit is contained in:
@@ -27,6 +27,7 @@ from parameterized import parameterized
|
||||
from transformers import CONFIG_MAPPING, Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_fp16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
@@ -1565,7 +1566,7 @@ class Blip2TextModelWithProjectionTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_model_from_pretrained(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
model = Blip2TextModelWithProjection.from_pretrained(model_name)
|
||||
@@ -2191,7 +2192,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertTrue(generated_text_expanded == generated_text)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_inference_itm(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
processor = Blip2Processor.from_pretrained(model_name)
|
||||
@@ -2210,7 +2211,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(torch.nn.Softmax()(out_itm[0].cpu()), expected_scores, rtol=1e-3, atol=1e-3))
|
||||
self.assertTrue(torch.allclose(out[0].cpu(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
def test_inference_itm_fp16(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
@@ -2232,7 +2233,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(torch.allclose(out[0].cpu().float(), torch.Tensor([[0.4406]]), rtol=1e-3, atol=1e-3))
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
def test_inference_vision_with_projection_fp16(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
@@ -2256,7 +2257,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
self.assertTrue(np.allclose(out.image_embeds[0][0][:6].tolist(), expected_image_embeds, atol=1e-3))
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@require_torch_fp16
|
||||
def test_inference_text_with_projection_fp16(self):
|
||||
model_name = "Salesforce/blip2-itm-vit-g"
|
||||
|
||||
Reference in New Issue
Block a user