Remove @slow for test_eager_matches_sdpa_inference (#34558)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2024-11-05 16:10:42 +01:00
committed by GitHub
parent 082e57e0d4
commit f2d5dfbab2
21 changed files with 271 additions and 626 deletions

View File

@@ -3928,7 +3928,6 @@ class ModelTesterMixin:
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
@@ -3954,8 +3953,10 @@ class ModelTesterMixin:
atols = {
("cpu", False, torch.float32): 1e-6,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-6,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-6,
("cuda", False, torch.bfloat16): 1e-2,
@@ -3966,8 +3967,10 @@ class ModelTesterMixin:
}
rtols = {
("cpu", False, torch.float32): 1e-4,
("cpu", False, torch.float16): 5e-3,
("cpu", False, torch.bfloat16): 1e-2,
("cpu", True, torch.float32): 1e-4,
("cpu", True, torch.float16): 5e-3,
("cpu", True, torch.bfloat16): 1e-2,
("cuda", False, torch.float32): 1e-4,
("cuda", False, torch.bfloat16): 1e-2,
@@ -3983,12 +3986,31 @@ class ModelTesterMixin:
if hasattr(self.model_tester, "num_hidden_layers"):
self.model_tester.num_hidden_layers = 1
if hasattr(self.model_tester, "vision_config") and "num_hidden_layers" in self.model_tester.vision_config:
self.model_tester.vision_config = copy.deepcopy(self.model_tester.vision_config)
self.model_tester.vision_config["num_hidden_layers"] = 1
if hasattr(self.model_tester, "text_config") and "num_hidden_layers" in self.model_tester.text_config:
self.model_tester.text_config = copy.deepcopy(self.model_tester.text_config)
self.model_tester.text_config["num_hidden_layers"] = 1
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.rms_norm_eps = 1.0
config.layer_norm_eps = 1.0
config.norm_eps = 1.0
config.norm_epsilon = 1.0
config.layer_norm_epsilon = 1.0
# norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
# (We don't need the original epsilon values to check eager/sdpa matches)
for attr in ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]:
if hasattr(config, attr):
getattr(config, attr).rms_norm_eps = 1.0
getattr(config, attr).layer_norm_eps = 1.0
getattr(config, attr).norm_eps = 1.0
getattr(config, attr).norm_epsilon = 1.0
getattr(config, attr).layer_norm_epsilon = 1.0
model = model_class(config)
# FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
# These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
@@ -4000,14 +4022,22 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype)
model_sdpa = model_sdpa.eval().to(torch_device)
model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype)
model_eager = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch_dtype,
attn_implementation="eager",
)
model_eager = model_eager.eval().to(torch_device)
model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype)
# Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.)
for x in model_eager.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
for x in model_sdpa.modules():
if isinstance(x, (nn.LayerNorm, nn.GroupNorm)):
x.eps = 1.0
# We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
# but it would be nicer to have an efficient way to use parameterized.expand