[marian tests ] pass device to pipeline (#4815)
This commit is contained in:
@@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest):
|
|||||||
self.tokenizer.prepare_translation_batch([""])
|
self.tokenizer.prepare_translation_batch([""])
|
||||||
|
|
||||||
def test_pipeline(self):
|
def test_pipeline(self):
|
||||||
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt")
|
device = 0 if torch_device == "cuda" else -1
|
||||||
|
pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=device)
|
||||||
output = pipeline(self.src_text)
|
output = pipeline(self.src_text)
|
||||||
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
|
self.assertEqual(self.expected_text, [x["translation_text"] for x in output])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user