Fix PvtModelIntegrationTest::test_inference_fp16 (#25106)

update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-07-26 14:57:44 +02:00
committed by GitHub
parent ee63520a7b
commit 31acba5697

View File

@@ -317,14 +317,13 @@ class PvtModelIntegrationTest(unittest.TestCase):
r"""
A small test to make sure that inference work in half precision without any problem.
"""
model = PvtForImageClassification.from_pretrained(
"Zetatech/pvt-tiny-224", torch_dtype=torch.float16, device_map="auto"
)
model = PvtForImageClassification.from_pretrained("Zetatech/pvt-tiny-224", torch_dtype=torch.float16)
model.to(torch_device)
image_processor = PvtImageProcessor(size=224)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device).astype(torch.float16)
pixel_values = inputs.pixel_values.to(torch_device, dtype=torch.float16)
# forward pass to make sure inference works in fp16
with torch.no_grad():