Add mask2former fp16 support (#25093)

* Add mask2former fp16 support

* Clear consistency/quality issues

* Fix consistency/quality (2)

* Add integration test for mask2former (fp16 case)

* Fix code quality

* Add integration test for maskformer (fp16 case)

* Add integration test for oneformer (fp16 case)

* Remove slow decorator from fp16 tests

* Fix lint

* Remove usage of full inference and value checks for fp16

* Temporarily comment slow for {mask, mask2, one}former

* Add fp16 support to oneformer

* Revert "Temporarily comment slow for {mask, mask2, one}former"

This reverts commit e5371edabd301cf56079def0421a0a87df307cb0.

* Remove dtype conversion noop
This commit is contained in:
Pedro Lira
2023-08-07 16:07:29 -03:00
committed by GitHub
parent 5ee9693a1c
commit 080a97119c
6 changed files with 97 additions and 37 deletions

View File

@@ -21,7 +21,14 @@ import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -420,6 +427,20 @@ class Mask2FormerModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
@require_torch_gpu
def test_inference_fp16(self):
model = (
Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
.to(torch_device, dtype=torch.float16)
.eval()
)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(image, return_tensors="pt").to(torch_device, dtype=torch.float16)
with torch.no_grad():
_ = model(**inputs)
def test_with_segmentation_maps_and_loss(self):
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
image_processor = self.default_image_processor