🚨🚨🚨 fix(Mask2Former): torch export 🚨🚨🚨 (#34393)
* fix(Mask2Former): torch export Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * revert level_start_index and create a level_start_index_list Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Add a comment to explain the level_start_index_list Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Address comment Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * add torch.export.export test Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * rename arg Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * remove spatial_shapes Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * Use the version check from pytorch_utils Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> * [run_slow] mask2former Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai> --------- Signed-off-by: Phillip Kuznetsov <philkuz@gimletlabs.ai>
This commit is contained in:
committed by
GitHub
parent
581524389a
commit
5fa4f64605
@@ -20,6 +20,7 @@ import numpy as np
|
||||
|
||||
from tests.test_modeling_common import floats_tensor
|
||||
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
|
||||
from transformers.testing_utils import (
|
||||
require_timm,
|
||||
require_torch,
|
||||
@@ -481,3 +482,28 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model(**inputs)
|
||||
|
||||
self.assertTrue(outputs.loss is not None)
|
||||
|
||||
def test_export(self):
|
||||
if not is_torch_greater_or_equal_than_2_4:
|
||||
self.skipTest(reason="This test requires torch >= 2.4 to run.")
|
||||
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
|
||||
image_processor = self.default_image_processor
|
||||
image = prepare_img()
|
||||
inputs = image_processor(image, return_tensors="pt").to(torch_device)
|
||||
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
args=(inputs["pixel_values"], inputs["pixel_mask"]),
|
||||
strict=True,
|
||||
)
|
||||
with torch.no_grad():
|
||||
eager_outputs = model(**inputs)
|
||||
exported_outputs = exported_program.module().forward(inputs["pixel_values"], inputs["pixel_mask"])
|
||||
self.assertEqual(eager_outputs.masks_queries_logits.shape, exported_outputs.masks_queries_logits.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(eager_outputs.masks_queries_logits, exported_outputs.masks_queries_logits, atol=TOLERANCE)
|
||||
)
|
||||
self.assertEqual(eager_outputs.class_queries_logits.shape, exported_outputs.class_queries_logits.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(eager_outputs.class_queries_logits, exported_outputs.class_queries_logits, atol=TOLERANCE)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user