Flash Attention 2 support for RoCm (#27611)

* support FA2

* fix typo

* fix broken tests

* fix more test errors

* left/right

* fix bug

* more test

* typo

* fix layout flash attention falcon

* do not support this case

* use allclose instead of equal

* fix various bugs with flash attention

* bump

* fix test

* fix mistral

* use skiptest instead of return that may be misleading

* add fix causal arg flash attention

* fix copies

* more explicit comment

* still use self.is_causal

* fix causal argument

* comment

* fixes

* update documentation

* add link

* wrong test

* simplify FA2 RoCm requirements

* update opt

* make flash_attn_uses_top_left_mask attribute private and precise comment

* better error handling

* fix copy & mistral

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/import_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* use is_flash_attn_greater_or_equal_2_10 instead of is_flash_attn_greater_or_equal_210

* fix merge

* simplify

* inline args

---------

Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
fxmarty
2023-12-04 13:52:17 +01:00
committed by GitHub
parent 4d4febb7aa
commit 1da1302ec8
17 changed files with 253 additions and 51 deletions

View File

@@ -3087,7 +3087,7 @@ class ModelTesterMixin:
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.equal(out, out_fa))
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu
@@ -3130,7 +3130,7 @@ class ModelTesterMixin:
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
self.assertTrue(torch.equal(out, out_fa))
self.assertTrue(torch.allclose(out, out_fa))
@require_flash_attn
@require_torch_gpu