From 31acba56976a7e2746e05f6050984c64006e572b Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 26 Jul 2023 14:57:44 +0200 Subject: [PATCH] Fix `PvtModelIntegrationTest::test_inference_fp16` (#25106) update Co-authored-by: ydshieh --- tests/models/pvt/test_modeling_pvt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/pvt/test_modeling_pvt.py b/tests/models/pvt/test_modeling_pvt.py index afc6fce79e..d2290e0a02 100644 --- a/tests/models/pvt/test_modeling_pvt.py +++ b/tests/models/pvt/test_modeling_pvt.py @@ -317,14 +317,13 @@ class PvtModelIntegrationTest(unittest.TestCase): r""" A small test to make sure that inference work in half precision without any problem. """ - model = PvtForImageClassification.from_pretrained( - "Zetatech/pvt-tiny-224", torch_dtype=torch.float16, device_map="auto" - ) + model = PvtForImageClassification.from_pretrained("Zetatech/pvt-tiny-224", torch_dtype=torch.float16) + model.to(torch_device) image_processor = PvtImageProcessor(size=224) image = prepare_img() inputs = image_processor(images=image, return_tensors="pt") - pixel_values = inputs.pixel_values.to(torch_device).astype(torch.float16) + pixel_values = inputs.pixel_values.to(torch_device, dtype=torch.float16) # forward pass to make sure inference works in fp16 with torch.no_grad():