From 03c14a515f637369c841d5acf0349a47a7996a0a Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Thu, 17 Mar 2022 10:53:57 +0100 Subject: [PATCH] [Tests] Fix DiT test (#16218) * Fix device * Clean up Co-authored-by: Niels Rogge --- tests/dit/test_modeling_dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dit/test_modeling_dit.py b/tests/dit/test_modeling_dit.py index ad78d1b172..e7e1474f17 100644 --- a/tests/dit/test_modeling_dit.py +++ b/tests/dit/test_modeling_dit.py @@ -43,7 +43,7 @@ class DiTIntegrationTest(unittest.TestCase): image = dataset["train"][0]["image"].convert("RGB") - inputs = feature_extractor(image, return_tensors="pt") + inputs = feature_extractor(image, return_tensors="pt").to(torch_device) # forward pass with torch.no_grad():