Fix PvtModelIntegrationTest::test_inference_fp16 (#25106)
update Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -317,14 +317,13 @@ class PvtModelIntegrationTest(unittest.TestCase):
|
||||
r"""
|
||||
A small test to make sure that inference work in half precision without any problem.
|
||||
"""
|
||||
model = PvtForImageClassification.from_pretrained(
|
||||
"Zetatech/pvt-tiny-224", torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
model = PvtForImageClassification.from_pretrained("Zetatech/pvt-tiny-224", torch_dtype=torch.float16)
|
||||
model.to(torch_device)
|
||||
image_processor = PvtImageProcessor(size=224)
|
||||
|
||||
image = prepare_img()
|
||||
inputs = image_processor(images=image, return_tensors="pt")
|
||||
pixel_values = inputs.pixel_values.to(torch_device).astype(torch.float16)
|
||||
pixel_values = inputs.pixel_values.to(torch_device, dtype=torch.float16)
|
||||
|
||||
# forward pass to make sure inference works in fp16
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user