From 4362ee298a9231bff6e7a37c634029c42fd835ed Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 30 Aug 2021 12:02:08 +0200 Subject: [PATCH] correct (#13304) --- tests/test_modeling_rag.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_modeling_rag.py b/tests/test_modeling_rag.py index 15bbea5237..de9afa46bc 100644 --- a/tests/test_modeling_rag.py +++ b/tests/test_modeling_rag.py @@ -988,6 +988,9 @@ class RagModelIntegrationTests(unittest.TestCase): torch_device ) + if torch_device == "cuda": + rag_token.half() + input_dict = tokenizer( self.test_data_questions, return_tensors="pt",