Fix PvtModelIntegrationTest::test_inference_fp16 (#25106)
update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -317,14 +317,13 @@ class PvtModelIntegrationTest(unittest.TestCase):
|
|||||||
r"""
|
r"""
|
||||||
A small test to make sure that inference work in half precision without any problem.
|
A small test to make sure that inference work in half precision without any problem.
|
||||||
"""
|
"""
|
||||||
model = PvtForImageClassification.from_pretrained(
|
model = PvtForImageClassification.from_pretrained("Zetatech/pvt-tiny-224", torch_dtype=torch.float16)
|
||||||
"Zetatech/pvt-tiny-224", torch_dtype=torch.float16, device_map="auto"
|
model.to(torch_device)
|
||||||
)
|
|
||||||
image_processor = PvtImageProcessor(size=224)
|
image_processor = PvtImageProcessor(size=224)
|
||||||
|
|
||||||
image = prepare_img()
|
image = prepare_img()
|
||||||
inputs = image_processor(images=image, return_tensors="pt")
|
inputs = image_processor(images=image, return_tensors="pt")
|
||||||
pixel_values = inputs.pixel_values.to(torch_device).astype(torch.float16)
|
pixel_values = inputs.pixel_values.to(torch_device, dtype=torch.float16)
|
||||||
|
|
||||||
# forward pass to make sure inference works in fp16
|
# forward pass to make sure inference works in fp16
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Reference in New Issue
Block a user