This commit is contained in:
Patrick von Platen
2021-08-30 12:02:08 +02:00
committed by GitHub
parent 4046e66e40
commit 4362ee298a

View File

@@ -988,6 +988,9 @@ class RagModelIntegrationTests(unittest.TestCase):
torch_device
)
if torch_device == "cuda":
rag_token.half()
input_dict = tokenizer(
self.test_data_questions,
return_tensors="pt",