device agnostic pipelines testing (#27129)

* device agnostic pipelines testing

* pass torch_device
This commit is contained in:
Hz, Ji
2023-10-31 22:46:31 +08:00
committed by GitHub
parent 08fadc8085
commit f53041a753
10 changed files with 64 additions and 58 deletions

View File

@@ -22,9 +22,10 @@ from transformers.testing_utils import (
nested_simplify,
require_tf,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_vision,
slow,
torch_device,
)
from .test_pipelines_common import ANY
@@ -91,7 +92,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
)
@require_torch
@require_torch_gpu
@require_torch_accelerator
def test_small_model_pt_blip2(self):
vqa_pipeline = pipeline(
"visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
@@ -112,9 +113,9 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
"visual-question-answering",
model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
model_kwargs={"torch_dtype": torch.float16},
device=0,
device=torch_device,
)
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
@@ -148,15 +149,15 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
@slow
@require_torch
@require_torch_gpu
@require_torch_accelerator
def test_large_model_pt_blip2(self):
vqa_pipeline = pipeline(
"visual-question-answering",
model="Salesforce/blip2-opt-2.7b",
model_kwargs={"torch_dtype": torch.float16},
device=0,
device=torch_device,
)
self.assertEqual(vqa_pipeline.model.device, torch.device(0))
self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"