make test_eager_matches_sdpa_inference less flaky (#34512)
* try * try * try * try * try * try * update * update * update * update * update * update * update --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1263,6 +1263,9 @@ class GenerationTesterMixin:
|
||||
|
||||
if model.get_output_embeddings() is None:
|
||||
self.skipTest("DoLa is not supported for models that don't have output embeddings")
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True, config=model.config)
|
||||
|
||||
# Sets dola generation arguments such that:
|
||||
# a) no EOS is generated, to ensure generation doesn't break early
|
||||
# b) there are at least two forward passes in the main model, to ensure the input preparation of
|
||||
@@ -1280,7 +1283,7 @@ class GenerationTesterMixin:
|
||||
"use_cache": getattr(config, "use_cache", False), # Some models don't support the cache
|
||||
"dola_layers": "low",
|
||||
}
|
||||
output_dola = model.generate(**generation_kwargs, **inputs_dict)
|
||||
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
|
||||
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
||||
|
||||
@pytest.mark.generate
|
||||
|
||||
@@ -85,7 +85,7 @@ class LlavaVisionText2TextModelTester:
|
||||
},
|
||||
is_training=True,
|
||||
vision_config={
|
||||
"image_size": 30,
|
||||
"image_size": 8,
|
||||
"patch_size": 2,
|
||||
"num_channels": 3,
|
||||
"is_training": True,
|
||||
@@ -118,9 +118,9 @@ class LlavaVisionText2TextModelTester:
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 336
|
||||
self.encoder_seq_length = 232
|
||||
self.num_image_tokens = 225
|
||||
self.num_image_tokens = (self.vision_config["image_size"] // self.vision_config["patch_size"]) ** 2
|
||||
self.seq_length = seq_length + self.num_image_tokens
|
||||
self.encoder_seq_length = self.seq_length
|
||||
|
||||
def get_config(self):
|
||||
return LlavaConfig(
|
||||
|
||||
@@ -85,7 +85,7 @@ class VipLlavaVisionText2TextModelTester:
|
||||
is_training=True,
|
||||
vision_config={
|
||||
"batch_size": 12,
|
||||
"image_size": 30,
|
||||
"image_size": 8,
|
||||
"patch_size": 2,
|
||||
"num_channels": 3,
|
||||
"is_training": True,
|
||||
@@ -117,9 +117,9 @@ class VipLlavaVisionText2TextModelTester:
|
||||
self.batch_size = 3
|
||||
self.num_channels = 3
|
||||
self.image_size = 336
|
||||
self.encoder_seq_length = 232
|
||||
self.num_image_tokens = 225
|
||||
self.num_image_tokens = (self.vision_config["image_size"] // self.vision_config["patch_size"]) ** 2
|
||||
self.seq_length = seq_length + self.num_image_tokens
|
||||
self.encoder_seq_length = self.seq_length
|
||||
|
||||
def get_config(self):
|
||||
return VipLlavaConfig(
|
||||
|
||||
@@ -3982,6 +3982,13 @@ class ModelTesterMixin:
|
||||
def get_mean_reldiff(failcase, x, ref, atol, rtol):
|
||||
return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
|
||||
|
||||
if hasattr(self.model_tester, "num_hidden_layers"):
|
||||
self.model_tester.num_hidden_layers = 1
|
||||
if hasattr(self.model_tester, "vision_config") and "num_hidden_layers" in self.model_tester.vision_config:
|
||||
self.model_tester.vision_config["num_hidden_layers"] = 1
|
||||
if hasattr(self.model_tester, "text_config") and "num_hidden_layers" in self.model_tester.text_config:
|
||||
self.model_tester.text_config["num_hidden_layers"] = 1
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
@@ -4013,7 +4020,8 @@ class ModelTesterMixin:
|
||||
can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
|
||||
if not (self.has_attentions and can_output_attn) and output_attentions:
|
||||
continue
|
||||
for batch_size in [1, 5]:
|
||||
# TODO: if we can also check with `batch_size=1` without being flaky?
|
||||
for batch_size in [7]:
|
||||
dummy_input = inputs_dict[model.main_input_name]
|
||||
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]:
|
||||
@@ -4064,14 +4072,14 @@ class ModelTesterMixin:
|
||||
|
||||
dummy_attention_mask[:] = 1
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[-1, :-1] = 1
|
||||
dummy_attention_mask[-1, -4:] = 0
|
||||
dummy_attention_mask[-1, :2] = 0
|
||||
dummy_attention_mask[-1, 2:] = 1
|
||||
elif padding_side == "right":
|
||||
dummy_attention_mask[-1, 1:] = 1
|
||||
dummy_attention_mask[-1, :3] = 0
|
||||
dummy_attention_mask[-1, -2:] = 0
|
||||
dummy_attention_mask[-1, :-2] = 1
|
||||
|
||||
for enable_kernels in [False, True]:
|
||||
failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
|
||||
failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}"
|
||||
if is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
|
||||
:batch_size
|
||||
@@ -4161,49 +4169,29 @@ class ModelTesterMixin:
|
||||
|
||||
# Masked tokens output slightly deviates - we don't mind that.
|
||||
if use_mask:
|
||||
_logits_sdpa = torch.zeros_like(input=logits_sdpa)
|
||||
_logits_eager = torch.zeros_like(input=logits_eager)
|
||||
|
||||
_logits_sdpa[:-1] = logits_sdpa[:-1]
|
||||
_logits_eager[:-1] = logits_eager[:-1]
|
||||
|
||||
if padding_side == "left":
|
||||
sub_sdpa = logits_sdpa[:-1]
|
||||
sub_eager = logits_eager[:-1]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
|
||||
_logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
|
||||
|
||||
sub_sdpa = logits_sdpa[-1, :-4]
|
||||
sub_eager = logits_eager[-1, :-4]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
|
||||
# Testing the padding tokens is not really meaningful but anyway
|
||||
# sub_sdpa = logits_sdpa[-1, -4:]
|
||||
# sub_eager = logits_eager[-1, -4:]
|
||||
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
|
||||
elif padding_side == "right":
|
||||
sub_sdpa = logits_sdpa[:-1]
|
||||
sub_eager = logits_eager[:-1]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
_logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
|
||||
_logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
|
||||
|
||||
sub_sdpa = logits_sdpa[-1, 3:]
|
||||
sub_eager = logits_eager[-1, 3:]
|
||||
if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, sub_sdpa, sub_eager, atol, rtol)
|
||||
)
|
||||
logits_sdpa = _logits_sdpa
|
||||
logits_eager = _logits_eager
|
||||
|
||||
# Testing the padding tokens is not really meaningful but anyway
|
||||
# sub_sdpa = logits_sdpa[-1, :3]
|
||||
# sub_eager = logits_eager[-1, :3]
|
||||
# if not torch.allclose(sub_sdpa, sub_eager, atol=atol, rtol=rtol):
|
||||
# fail_cases.append(get_mean_reldiff(failcase, sub_sdpa, sub_eager, 4e-2, 4e-2))
|
||||
|
||||
else:
|
||||
if not torch.allclose(logits_sdpa, logits_eager, atol=atol, rtol=rtol):
|
||||
results = [
|
||||
torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
|
||||
for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
|
||||
]
|
||||
# If 80% batch elements have matched results, it's fine
|
||||
if np.mean(results) < 0.8:
|
||||
fail_cases.append(
|
||||
get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user