This commit is contained in:
Patrick von Platen
2021-08-30 12:02:08 +02:00
committed by GitHub
parent 4046e66e40
commit 4362ee298a

View File

@@ -988,6 +988,9 @@ class RagModelIntegrationTests(unittest.TestCase):
torch_device torch_device
) )
if torch_device == "cuda":
rag_token.half()
input_dict = tokenizer( input_dict = tokenizer(
self.test_data_questions, self.test_data_questions,
return_tensors="pt", return_tensors="pt",