Add accelerate support for ViT family (#20174)

* add `accelerate` support for `ViT` family

- add `_no_split_modules`
- manually cast to the right `dtype`: to change

* enable `float16` for `deit`

* fix `make fixup`

* add `slow` test for `fp16` inference

* another safety check

* Update src/transformers/models/deit/modeling_deit.py
This commit is contained in:
Younes Belkada
2022-11-15 11:06:01 +01:00
committed by GitHub
parent 11b2e45ccc
commit f1e8c48c5e
4 changed files with 82 additions and 6 deletions

View File

@@ -21,7 +21,14 @@ import warnings
from transformers import DeiTConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@@ -394,3 +401,23 @@ class DeiTModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([-1.0266, 0.1912, -1.2861]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
@require_accelerate
@require_torch_gpu
def test_inference_fp16(self):
    r"""
    Smoke test for half-precision inference: load the DeiT checkpoint with `torch_dtype=torch.float16` and
    `device_map="auto"`, then run a single forward pass. Nothing is asserted on the output; the test passes as
    long as the forward pass completes without raising.
    """
    checkpoint = "facebook/deit-base-distilled-patch16-224"
    model = DeiTModel.from_pretrained(checkpoint, torch_dtype=torch.float16, device_map="auto")

    # Preprocess the fixture image; the tensor is only moved to the test device here (no explicit dtype cast).
    encoding = self.default_feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding.pixel_values.to(torch_device)

    # Gradients are not needed for an inference-only check; the result is intentionally discarded.
    with torch.no_grad():
        model(pixel_values)