Respect the config's attn_implementation if set (#32383)

* Respect the config's attn if set

* Update test - can override in from_config

* Fix
This commit is contained in:
amyeroberts
2024-08-05 16:33:19 +01:00
committed by GitHub
parent 458b0cd2c5
commit 7e5d46ded4
2 changed files with 63 additions and 1 deletions

View File

@@ -574,6 +574,60 @@ class ModelUtilsTest(TestCasePlus):
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
def test_model_from_config_attn_implementation(self):
# test that the model can be instantiated with attn_implementation of either
# 1. config created with explicit attn_implementatation and from_config
# 2. explicit from_config's attn_implementation argument with a config argument
# 3. config created with explicit attn_implementatation and from_config overriding with explicit attn_implementation argument
attn_implementation_available = ["eager"]
if is_torch_sdpa_available():
attn_implementation_available.append("sdpa")
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
"flash_attention_2": "MistralFlashAttention2",
}
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
# Ensure the config was set correctly
self.assertEqual(config._attn_implementation, requested_attn_implementation)
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
model = AutoModelForCausalLM.from_config(config)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
config = AutoConfig.from_pretrained(TINY_MISTRAL)
# When the config is not set, the default is "eager"
self.assertEqual(config._attn_implementation, "eager")
self.assertEqual(config._attn_implementation_internal, None)
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
self.assertEqual(config._attn_implementation, "foo-bar-baz")
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
for module in model.modules():
if "Attention" in module.__class__.__name__:
self.assertEqual(
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
)
def test_torch_dtype_byte_sizes(self):
torch_dtypes_and_bytes = [
(torch.double, 8),