Fix SDPA sliding window compatibility (#30127)
* fix sdpa + sliding window
* give credit
  Co-authored-by: ehuaa <ehuamail@163.com>
* remove unnecessary warning
* fix typo
* add test
---------
Co-authored-by: ehuaa <ehuamail@163.com>
This commit is contained in:
@@ -3841,6 +3841,57 @@ class ModelTesterMixin:
|
||||
|
||||
self.assertTrue(torch.allclose(res_eager, res_sdpa))
|
||||
|
||||
@require_torch_sdpa
def test_sdpa_matches_eager_sliding_window(self):
    """Check that SDPA and eager attention agree when a sliding window is set.

    For every generative model class of the tester, a model is built from the
    config with ``attn_implementation="eager"``, saved to a temporary
    directory and reloaded with ``attn_implementation="sdpa"`` so both models
    share the same weights. ``config.sliding_window`` is forced to 2 and the
    first output of both models is compared on the positions where
    ``attention_mask == 1``.

    The test is skipped when the tester has no generative model classes or
    when ``config.model_type`` is not listed in ``WINDOW_ATTENTION_MODELS``.
    """
    # Values are compared against `config.model_type`, so each entry must be
    # the exact registered model type (Qwen2-MoE registers as "qwen2_moe").
    WINDOW_ATTENTION_MODELS = ["mistral", "mixtral", "qwen2", "qwen2_moe", "starcoder2"]

    if not self.all_generative_model_classes:
        self.skipTest(f"No generative model classes for {self.__class__.__name__}")

    for model_class in self.all_generative_model_classes:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        if config.model_type not in WINDOW_ATTENTION_MODELS:
            self.skipTest(f"{config.model_type} does not use window attention")

        # A window smaller than the sequence length, so the sliding-window
        # mask actually differs from a plain causal mask.
        config.sliding_window = 2

        dummy_input = inputs_dict[model_class.main_input_name]
        attention_mask = inputs_dict["attention_mask"]

        # assertEqual/assertGreater report the offending values on failure.
        self.assertEqual(dummy_input.ndim, 2)
        self.assertGreater(dummy_input.shape[1], 6)

        with tempfile.TemporaryDirectory() as tmpdir:
            with torch.device(torch_device):
                model_eager = AutoModelForCausalLM.from_config(
                    config, attn_implementation="eager", torch_dtype=torch.float32
                )

            # Round-trip through disk so the SDPA model gets identical weights.
            model_eager.save_pretrained(tmpdir)

            with torch.device(torch_device):
                model_sdpa = AutoModelForCausalLM.from_pretrained(
                    tmpdir, attn_implementation="sdpa", torch_dtype=torch.float32
                )

            model_eager = model_eager.eval()
            model_sdpa = model_sdpa.eval()

            # Only the math SDPA kernel is enabled; flash and mem-efficient are off.
            with torch.no_grad(), torch.backends.cuda.sdp_kernel(
                enable_flash=False,
                enable_math=True,
                enable_mem_efficient=False,
            ):
                res_eager = model_eager(**inputs_dict, return_dict=False)[0]
                res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]

            # Only non-padding tokens are expected to match.
            self.assertTrue(
                torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-3)
            )
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user