[Tests] Fix DiT test (#16218)

* Fix device

* Clean up

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-03-17 10:53:57 +01:00
committed by GitHub
parent 73f0a5d1f6
commit 03c14a515f

View File

@@ -43,7 +43,7 @@ class DiTIntegrationTest(unittest.TestCase):
image = dataset["train"][0]["image"].convert("RGB") image = dataset["train"][0]["image"].convert("RGB")
inputs = feature_extractor(image, return_tensors="pt") inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
# forward pass # forward pass
with torch.no_grad(): with torch.no_grad():