Add accelerate support for ViT family (#20174)
* add `accelerate` support for `ViT` family - add `_no_split_modules` - manually cast to the right `dtype`: to change * enable `float16` for `deit` * fix `make fixup` * add `slow` test for `fp16` inference * another safety check * Update src/transformers/models/deit/modeling_deit.py
This commit is contained in:
@@ -19,7 +19,14 @@ import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import ViTConfig
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_vision,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -300,3 +307,21 @@ class ViTModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
|
||||
|
||||
@slow
|
||||
@require_accelerate
|
||||
@require_torch_gpu
|
||||
def test_inference_fp16(self):
|
||||
r"""
|
||||
A small test to make sure that inference work in half precision without any problem.
|
||||
"""
|
||||
model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
|
||||
feature_extractor = self.default_feature_extractor
|
||||
|
||||
image = prepare_img()
|
||||
inputs = feature_extractor(images=image, return_tensors="pt")
|
||||
pixel_values = inputs.pixel_values.to(torch_device)
|
||||
|
||||
# forward pass to make sure inference works in fp16
|
||||
with torch.no_grad():
|
||||
_ = model(pixel_values)
|
||||
|
||||
Reference in New Issue
Block a user