Respect the config's attn_implementation if set (#32383)
* Respect the config's attn if set * Update test - can override in from_config * Fix
This commit is contained in:
@@ -1454,7 +1454,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
|
||||
config._attn_implementation = kwargs.pop("attn_implementation", None)
|
||||
|
||||
if config._attn_implementation_internal is not None:
|
||||
# In this case, the config has been created with the attn_implementation set by the user, which we
|
||||
# should respect.
|
||||
attn_implementation = config._attn_implementation_internal
|
||||
else:
|
||||
attn_implementation = None
|
||||
|
||||
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
|
||||
config = cls._autoset_attn_implementation(
|
||||
config,
|
||||
use_flash_attention_2=use_flash_attention_2,
|
||||
|
||||
@@ -574,6 +574,60 @@ class ModelUtilsTest(TestCasePlus):
|
||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||
)
|
||||
|
||||
def test_model_from_config_attn_implementation(self):
|
||||
# test that the model can be instantiated with attn_implementation of either
|
||||
# 1. config created with explicit attn_implementatation and from_config
|
||||
# 2. explicit from_config's attn_implementation argument with a config argument
|
||||
# 3. config created with explicit attn_implementatation and from_config overriding with explicit attn_implementation argument
|
||||
attn_implementation_available = ["eager"]
|
||||
if is_torch_sdpa_available():
|
||||
attn_implementation_available.append("sdpa")
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
attn_implementation_available.append("flash_attention_2")
|
||||
|
||||
mistral_attention_classes = {
|
||||
"eager": "MistralAttention",
|
||||
"sdpa": "MistralSdpaAttention",
|
||||
"flash_attention_2": "MistralFlashAttention2",
|
||||
}
|
||||
for requested_attn_implementation in attn_implementation_available:
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
|
||||
# Ensure the config was set correctly
|
||||
self.assertEqual(config._attn_implementation, requested_attn_implementation)
|
||||
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
self.assertEqual(
|
||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL)
|
||||
# When the config is not set, the default is "eager"
|
||||
self.assertEqual(config._attn_implementation, "eager")
|
||||
self.assertEqual(config._attn_implementation_internal, None)
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
self.assertEqual(
|
||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||
)
|
||||
|
||||
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
|
||||
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
|
||||
self.assertEqual(config._attn_implementation, "foo-bar-baz")
|
||||
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
|
||||
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
self.assertEqual(
|
||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||
)
|
||||
|
||||
def test_torch_dtype_byte_sizes(self):
|
||||
torch_dtypes_and_bytes = [
|
||||
(torch.double, 8),
|
||||
|
||||
Reference in New Issue
Block a user