Kernels flash attn (#39474)
* use partial to wrap around `transformers` utils! * try to refactor? * revert one wrong change * just a nit * push * reverter watever was wrong! * some nits * fixes when there is no attention mask * bring the licence back * some fixes * nit * style * remove prints * correct dtype * fa flags for testing * update * use paged attention if requested! * updates * a clone was needed, not sure why * automatically create cu seq lens when input is flash, this at least makes sure layers don't re-compute * simplify and improve? * flash attention is kinda broken on recent cuda version so allow the opportunity to use something else * fix! * protect kernels import * update * properly parse generation config being passed * revert and update * add two tests * some fixes * fix test FA2 * takes comment into account * fixup * revert changes * revert the clone, it is only needed because the metal kernel is not doing it? * [docs] update attention implementation and cache docs (#39547) * update docs * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * applu suggestions --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * fix mps on our side for now * Update src/transformers/integrations/flash_paged.py * no qa --------- Co-authored-by: Vasqu <antonprogamer@gmail.com> Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -86,12 +86,14 @@ from transformers.testing_utils import (
|
||||
require_deepspeed,
|
||||
require_flash_attn,
|
||||
require_flash_attn_3,
|
||||
require_kernels,
|
||||
require_non_hpu,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_gpu,
|
||||
require_torch_greater_or_equal,
|
||||
require_torch_mps,
|
||||
require_torch_multi_accelerator,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_sdpa,
|
||||
@@ -3474,94 +3476,107 @@ class ModelTesterMixin:
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.head_dim = 64 # fa2 does not always support arbitrary headim
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
model.to(torch_device)
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model.main_input_name][:1]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
if padding_side == "left":
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
else:
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
if model.config.is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||
|
||||
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
else:
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
other_inputs = {
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
else:
|
||||
other_inputs = {
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
outputs_fa = model_fa(dummy_input, **other_inputs)
|
||||
|
||||
logits = (
|
||||
outputs.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
dummy_attention_mask = inputs_dict.get("attention_mask", None)
|
||||
|
||||
if dummy_attention_mask is not None:
|
||||
dummy_attention_mask = dummy_attention_mask[:1]
|
||||
if padding_side == "left":
|
||||
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
_ = model_fa(dummy_input, **other_inputs)
|
||||
dummy_attention_mask[:, 1:] = 1
|
||||
dummy_attention_mask[:, :1] = 0
|
||||
else:
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
dummy_attention_mask[:, :-1] = 1
|
||||
dummy_attention_mask[:, -1:] = 0
|
||||
if model.config.is_encoder_decoder:
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
|
||||
|
||||
outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
|
||||
else:
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, output_hidden_states=True)
|
||||
|
||||
model.set_attn_implementation("sdpa")
|
||||
logits = (
|
||||
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
other_inputs = {
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": dummy_attention_mask,
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, **other_inputs)
|
||||
else:
|
||||
other_inputs = {
|
||||
"output_hidden_states": True,
|
||||
}
|
||||
if dummy_attention_mask is not None:
|
||||
other_inputs["attention_mask"] = dummy_attention_mask
|
||||
|
||||
outputs = model(dummy_input, **other_inputs)
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
outputs_fa = model(dummy_input, **other_inputs)
|
||||
|
||||
model.set_attn_implementation("sdpa")
|
||||
logits = (
|
||||
outputs.hidden_states[-1] if not model.config.is_encoder_decoder else outputs.decoder_hidden_states[-1]
|
||||
)
|
||||
logits_fa = (
|
||||
outputs_fa.hidden_states[-1]
|
||||
if not model.config.is_encoder_decoder
|
||||
else outputs_fa.decoder_hidden_states[-1]
|
||||
)
|
||||
|
||||
if padding_side == "left":
|
||||
assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
|
||||
|
||||
# check with inference + dropout
|
||||
model.train()
|
||||
model.set_attn_implementation(attn_implementation)
|
||||
_ = model(dummy_input, **other_inputs)
|
||||
else:
|
||||
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
|
||||
|
||||
@require_kernels
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_kernels_inference_equivalence(self):
|
||||
self.flash_attn_inference_equivalence(attn_implementation="kernels-community/flash-attn3", padding_side="left")
|
||||
|
||||
@require_torch_mps
|
||||
@require_kernels
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@is_flaky()
|
||||
def test_flash_attn_kernels_mps_inference_equivalence(self):
|
||||
self.flash_attn_inference_equivalence(
|
||||
attn_implementation="kernels-community/metal-flash-sdpa", padding_side="left"
|
||||
)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user