device agnostic pipelines testing (#27129)

* device agnostic pipelines testing

* pass torch_device
This commit is contained in:
Hz, Ji
2023-10-31 22:46:31 +08:00
committed by GitHub
parent 08fadc8085
commit f53041a753
10 changed files with 64 additions and 58 deletions

View File

@@ -40,15 +40,17 @@ from transformers.testing_utils import (
USER,
CaptureLogger,
RequestCounter,
backend_empty_cache,
is_pipeline_test,
is_staging_test,
nested_simplify,
require_tensorflow_probability,
require_tf,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_torch_or_tf,
slow,
torch_device,
)
from transformers.utils import direct_transformers_import, is_tf_available, is_torch_available
from transformers.utils import logging as transformers_logging
@@ -511,7 +513,7 @@ class PipelineUtilsTest(unittest.TestCase):
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@slow
@require_tf
@@ -541,20 +543,20 @@ class PipelineUtilsTest(unittest.TestCase):
# clean-up as much as possible GPU memory occupied by PyTorch
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@slow
@require_torch
@require_torch_gpu
def test_pipeline_cuda(self):
pipe = pipeline("text-generation", device="cuda")
@require_torch_accelerator
def test_pipeline_accelerator(self):
pipe = pipeline("text-generation", device=torch_device)
_ = pipe("Hello")
@slow
@require_torch
@require_torch_gpu
def test_pipeline_cuda_indexed(self):
pipe = pipeline("text-generation", device="cuda:0")
@require_torch_accelerator
def test_pipeline_accelerator_indexed(self):
pipe = pipeline("text-generation", device=torch_device)
_ = pipe("Hello")
@slow