device agnostic pipelines testing (#27129)

* device agnostic pipelines testing

* pass torch_device
This commit is contained in:
Hz, Ji
2023-10-31 22:46:31 +08:00
committed by GitHub
parent 08fadc8085
commit f53041a753
10 changed files with 64 additions and 58 deletions

View File

@@ -27,9 +27,6 @@ from transformers.tokenization_utils import TruncationStrategy
from .test_pipelines_common import ANY
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
@is_pipeline_test
class SummarizationPipelineTests(unittest.TestCase):
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
@@ -106,7 +103,7 @@ class SummarizationPipelineTests(unittest.TestCase):
@require_torch
@slow
def test_integration_torch_summarization(self):
summarizer = pipeline(task="summarization", device=DEFAULT_DEVICE_NUM)
summarizer = pipeline(task="summarization", device=torch_device)
cnn_article = (
" (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"