device agnostic pipelines testing (#27129)
* device agnostic pipelines testing * pass torch_device
This commit is contained in:
@@ -22,9 +22,10 @@ from transformers.testing_utils import (
|
||||
nested_simplify,
|
||||
require_tf,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_accelerator,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from .test_pipelines_common import ANY
|
||||
@@ -91,7 +92,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_small_model_pt_blip2(self):
|
||||
vqa_pipeline = pipeline(
|
||||
"visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
|
||||
@@ -112,9 +113,9 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
"visual-question-answering",
|
||||
model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
|
||||
model_kwargs={"torch_dtype": torch.float16},
|
||||
device=0,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
|
||||
self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
|
||||
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
|
||||
self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
|
||||
|
||||
@@ -148,15 +149,15 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_large_model_pt_blip2(self):
|
||||
vqa_pipeline = pipeline(
|
||||
"visual-question-answering",
|
||||
model="Salesforce/blip2-opt-2.7b",
|
||||
model_kwargs={"torch_dtype": torch.float16},
|
||||
device=0,
|
||||
device=torch_device,
|
||||
)
|
||||
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
|
||||
self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
|
||||
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
|
||||
|
||||
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
|
||||
|
||||
Reference in New Issue
Block a user