[FA-2] Fix fa-2 issue when passing config to from_pretrained (#28043)
* fix fa-2 issue * fix test * Update src/transformers/modeling_utils.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * clenaer fix * up * add more robust tests * Update src/transformers/modeling_utils.py Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * fixup * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * pop * add test --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -1823,6 +1823,16 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
|
||||
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
|
||||
|
||||
def test_error_no_flash_available_with_config(self):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
|
||||
|
||||
_ = AutoModel.from_pretrained(
|
||||
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
|
||||
)
|
||||
|
||||
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
|
||||
|
||||
def test_error_wrong_attn_implementation(self):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
|
||||
@@ -1840,6 +1850,21 @@ class TestAttentionImplementation(unittest.TestCase):
|
||||
|
||||
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
|
||||
|
||||
def test_not_available_flash_with_config(self):
|
||||
if is_flash_attn_2_available():
|
||||
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")
|
||||
|
||||
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
|
||||
|
||||
with self.assertRaises(ImportError) as cm:
|
||||
_ = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-GPTBigCodeModel",
|
||||
config=config,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
||||
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
|
||||
|
||||
def test_not_available_sdpa(self):
|
||||
if is_torch_sdpa_available():
|
||||
self.skipTest("This test requires torch<=2.0")
|
||||
|
||||
Reference in New Issue
Block a user