Flash Attention 2 support for RoCm (#27611)
* support FA2 * fix typo * fix broken tests * fix more test errors * left/right * fix bug * more test * typo * fix layout flash attention falcon * do not support this case * use allclose instead of equal * fix various bugs with flash attention * bump * fix test * fix mistral * use skiptest instead of return that may be misleading * add fix causal arg flash attention * fix copies * more explicit comment * still use self.is_causal * fix causal argument * comment * fixes * update documentation * add link * wrong test * simplify FA2 RoCm requirements * update opt * make flash_attn_uses_top_left_mask attribute private and precise comment * better error handling * fix copy & mistral * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/import_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210 * fix merge * simplify * inline args --------- Co-authored-by: Felix Marty <felix@hf.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -3087,7 +3087,7 @@ class ModelTesterMixin:
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(out, out_fa))
|
||||
self.assertTrue(torch.allclose(out, out_fa))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@@ -3130,7 +3130,7 @@ class ModelTesterMixin:
|
||||
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
|
||||
)
|
||||
|
||||
self.assertTrue(torch.equal(out, out_fa))
|
||||
self.assertTrue(torch.allclose(out, out_fa))
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user