device agnostic pipelines testing (#27129)
* device agnostic pipelines testing * pass torch_device
This commit is contained in:
@@ -27,9 +27,6 @@ from transformers.tokenization_utils import TruncationStrategy
|
||||
from .test_pipelines_common import ANY
|
||||
|
||||
|
||||
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class SummarizationPipelineTests(unittest.TestCase):
|
||||
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
|
||||
@@ -106,7 +103,7 @@ class SummarizationPipelineTests(unittest.TestCase):
|
||||
@require_torch
|
||||
@slow
|
||||
def test_integration_torch_summarization(self):
|
||||
summarizer = pipeline(task="summarization", device=DEFAULT_DEVICE_NUM)
|
||||
summarizer = pipeline(task="summarization", device=torch_device)
|
||||
cnn_article = (
|
||||
" (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
|
||||
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
|
||||
|
||||
Reference in New Issue
Block a user