Add mask2former fp16 support (#25093)

* Add mask2former fp16 support

* Clear consistency/quality issues

* Fix consistency/quality (2)

* Add integration test for mask2former (fp16 case)

* Fix code quality

* Add integration test for maskformer (fp16 case)

* Add integration test for oneformer (fp16 case)

* Remove slow decorator from fp16 tests

* Fix lint

* Remove usage of full inference and value checks for fp16

* Temporarily comment slow for {mask, mask2, one}former

* Add fp16 support to oneformer

* Revert "Temporarily comment slow for {mask, mask2, one}former"

This reverts commit e5371edabd301cf56079def0421a0a87df307cb0.

* Remove dtype conversion noop
This commit is contained in:
Pedro Lira
2023-08-07 16:07:29 -03:00
committed by GitHub
parent 5ee9693a1c
commit 080a97119c
6 changed files with 97 additions and 37 deletions

View File

@@ -22,7 +22,14 @@ import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import OneFormerConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.testing_utils import (
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
@@ -533,6 +540,20 @@ class OneFormerModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
@require_torch_gpu
def test_inference_fp16(self):
model = (
OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
.to(torch_device, dtype=torch.float16)
.eval()
)
processor = self.default_processor
image = prepare_img()
inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device, dtype=torch.float16)
with torch.no_grad():
_ = model(**inputs)
def test_with_segmentation_maps_and_loss(self):
dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
processor = self.default_processor