From c58e6c129a153ca1a5021e5d7e642d00bf011e20 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sat, 6 Jun 2020 00:52:17 -0400 Subject: [PATCH] [marian tests ] pass device to pipeline (#4815) --- tests/test_modeling_marian.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 46377da348..c66094c447 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -233,7 +233,8 @@ class TestMarian_en_ROMANCE(MarianIntegrationTest): self.tokenizer.prepare_translation_batch([""]) def test_pipeline(self): - pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt") + device = 0 if torch_device == "cuda" else -1 + pipeline = TranslationPipeline(self.model, self.tokenizer, framework="pt", device=device) output = pipeline(self.src_text) self.assertEqual(self.expected_text, [x["translation_text"] for x in output])