Add sdpa for Beit (#34941)

* Add sdpa for Beit

* Updates

* [run-slow] beit

* Update inference benchmarks

* Update

* Fix - add missed to super().forward()

* Updates

* Fix missing import
This commit is contained in:
Omar Salman
2024-12-17 18:44:47 +05:00
committed by GitHub
parent 6c08b3b6e5
commit 747f361da1
7 changed files with 649 additions and 9 deletions

View File

@@ -14,18 +14,35 @@
# limitations under the License.
"""Testing suite for the PyTorch BEiT model."""
import inspect
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from packaging import version
from parameterized import parameterized
from transformers import BeitConfig
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import (
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
cached_property,
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_vision_available,
)
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -74,6 +91,8 @@ class BeitModelTester:
scope=None,
out_indices=[1, 2, 3, 4],
out_features=["stage1", "stage2", "stage3", "stage4"],
attn_implementation="eager",
mask_ratio=0.5,
):
self.parent = parent
self.vocab_size = vocab_size
@@ -100,6 +119,8 @@ class BeitModelTester:
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.num_masks = int(mask_ratio * self.seq_length)
self.attn_implementation = attn_implementation
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -131,6 +152,7 @@ class BeitModelTester:
initializer_range=self.initializer_range,
out_indices=self.out_indices,
out_features=self.out_features,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
@@ -387,6 +409,193 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
model = BeitModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
# The common test modifies the num_hidden_layers to be 1. However, for Beit we want to
# avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code
# related to attention masks in the original common tests is not required as the Beit
# model does not handle attention masks. Furthermore, some extra code like modifying
# the norm layers eps values for specialized configs and checking for the 'noise'
# has been omitted to simply the test.
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
use_mask_token=True,
)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
# Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
for x in model_eager.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = outputs_eager.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
# We will verify our results on an image of cute cats
def prepare_img():

View File

@@ -14,14 +14,32 @@
# limitations under the License.
"""Testing suite for the PyTorch Data2VecVision model."""
import inspect
import tempfile
import unittest
import numpy as np
from parameterized import parameterized
from transformers import Data2VecVisionConfig
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import (
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
cached_property,
is_torch_available,
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
is_vision_available,
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
from ...test_pipeline_mixin import PipelineTesterMixin
@@ -66,6 +84,8 @@ class Data2VecVisionModelTester:
num_labels=3,
scope=None,
out_indices=[0, 1, 2, 3],
attn_implementation="eager",
mask_ratio=0.5,
):
self.parent = parent
self.vocab_size = 100
@@ -91,6 +111,8 @@ class Data2VecVisionModelTester:
# in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches + 1
self.num_masks = int(mask_ratio * self.seq_length)
self.attn_implementation = attn_implementation
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -121,6 +143,7 @@ class Data2VecVisionModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
out_indices=self.out_indices,
attn_implementation=self.attn_implementation,
)
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
@@ -300,6 +323,194 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
model = Data2VecVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
# Copied from tests.models.beit.test_modeling_beit.BeitModelTest.test_eager_matches_sdpa_inference with Beit->Data2VecVision
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
# The common test modifies the num_hidden_layers to be 1. However, for Data2VecVision we want to
# avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code
# related to attention masks in the original common tests is not required as the Data2VecVision
# model does not handle attention masks. Furthermore, some extra code like modifying
# the norm layers eps values for specialized configs and checking for the 'noise'
# has been omitted to simply the test.
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self.all_model_classes[0]._supports_sdpa:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device):
self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)")
if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device):
self.skipTest(
f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)"
)
# Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead.
if torch_dtype == "float16":
torch_dtype = torch.float16
elif torch_dtype == "bfloat16":
torch_dtype = torch.bfloat16
elif torch_dtype == "float32":
torch_dtype = torch.float32
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-6,
("cuda", True, torch.bfloat16): 1e-2,
("cuda", True, torch.float16): 5e-3,
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
("cuda", False, torch.float16): 5e-3,
("cuda", True, torch.float32): 1e-4,
("cuda", True, torch.bfloat16): 3e-2,
("cuda", True, torch.float16): 5e-3,
}
def get_mean_reldiff(failcase, x, ref, atol, rtol):
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
use_mask_token=True,
)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
# Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
for x in model_eager.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand
fail_cases = []
for padding_side in ["left", "right"]:
for use_mask in [False, True]:
for output_attentions in [True, False]:
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
if not (self.has_attentions and can_output_attn) and output_attentions:
continue
# TODO: if we can also check with `batch_size=1` without being flaky?
for batch_size in [7]:
dummy_input = inputs_dict[model.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
dummy_input = dummy_input.to(torch_dtype)
dummy_input = dummy_input[:batch_size]
for enable_kernels in [False, True]:
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
processed_inputs = {
model.main_input_name: dummy_input,
"output_hidden_states": True,
}
if (
self.has_attentions
and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
):
processed_inputs["output_attentions"] = output_attentions
if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters:
dummy_mask = torch.ones((self.model_tester.num_masks,))
mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0)
dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)])
dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device)
with torch.no_grad():
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
):
prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
outputs_eager = model_eager(**prepared_inputs)
outputs_sdpa = model_sdpa(**prepared_inputs)
logits_eager = outputs_eager.hidden_states[-1]
logits_sdpa = outputs_sdpa.hidden_states[-1]
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
# Masked tokens output slightly deviates - we don't mind that.
if use_mask:
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
_logits_eager = torch.zeros_like(input=logits_eager)
_logits_sdpa[:-1] = logits_sdpa[:-1]
_logits_eager[:-1] = logits_eager[:-1]
if padding_side == "left":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
elif padding_side == "right":
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
logits_sdpa = _logits_sdpa
logits_eager = _logits_eager
results = [
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
]
# If 80% batch elements have matched results, it's fine
if np.mean(results) < 0.8:
fail_cases.append(
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
)
self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
# We will verify our results on an image of cute cats
def prepare_img():