Respect the config's attn_implementation if set (#32383)
* Respect the config's attn if set * Update test - can override in from_config * Fix
This commit is contained in:
@@ -1454,7 +1454,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
|
config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config.
|
||||||
config._attn_implementation = kwargs.pop("attn_implementation", None)
|
|
||||||
|
if config._attn_implementation_internal is not None:
|
||||||
|
# In this case, the config has been created with the attn_implementation set by the user, which we
|
||||||
|
# should respect.
|
||||||
|
attn_implementation = config._attn_implementation_internal
|
||||||
|
else:
|
||||||
|
attn_implementation = None
|
||||||
|
|
||||||
|
config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
|
||||||
config = cls._autoset_attn_implementation(
|
config = cls._autoset_attn_implementation(
|
||||||
config,
|
config,
|
||||||
use_flash_attention_2=use_flash_attention_2,
|
use_flash_attention_2=use_flash_attention_2,
|
||||||
|
|||||||
@@ -574,6 +574,60 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_model_from_config_attn_implementation(self):
|
||||||
|
# test that the model can be instantiated with attn_implementation of either
|
||||||
|
# 1. config created with explicit attn_implementatation and from_config
|
||||||
|
# 2. explicit from_config's attn_implementation argument with a config argument
|
||||||
|
# 3. config created with explicit attn_implementatation and from_config overriding with explicit attn_implementation argument
|
||||||
|
attn_implementation_available = ["eager"]
|
||||||
|
if is_torch_sdpa_available():
|
||||||
|
attn_implementation_available.append("sdpa")
|
||||||
|
|
||||||
|
if is_flash_attn_2_available():
|
||||||
|
attn_implementation_available.append("flash_attention_2")
|
||||||
|
|
||||||
|
mistral_attention_classes = {
|
||||||
|
"eager": "MistralAttention",
|
||||||
|
"sdpa": "MistralSdpaAttention",
|
||||||
|
"flash_attention_2": "MistralFlashAttention2",
|
||||||
|
}
|
||||||
|
for requested_attn_implementation in attn_implementation_available:
|
||||||
|
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
|
||||||
|
# Ensure the config was set correctly
|
||||||
|
self.assertEqual(config._attn_implementation, requested_attn_implementation)
|
||||||
|
self.assertEqual(config._attn_implementation_internal, requested_attn_implementation)
|
||||||
|
model = AutoModelForCausalLM.from_config(config)
|
||||||
|
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||||
|
for module in model.modules():
|
||||||
|
if "Attention" in module.__class__.__name__:
|
||||||
|
self.assertEqual(
|
||||||
|
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||||
|
)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(TINY_MISTRAL)
|
||||||
|
# When the config is not set, the default is "eager"
|
||||||
|
self.assertEqual(config._attn_implementation, "eager")
|
||||||
|
self.assertEqual(config._attn_implementation_internal, None)
|
||||||
|
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||||
|
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||||
|
for module in model.modules():
|
||||||
|
if "Attention" in module.__class__.__name__:
|
||||||
|
self.assertEqual(
|
||||||
|
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument
|
||||||
|
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz")
|
||||||
|
self.assertEqual(config._attn_implementation, "foo-bar-baz")
|
||||||
|
self.assertEqual(config._attn_implementation_internal, "foo-bar-baz")
|
||||||
|
model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation)
|
||||||
|
self.assertEqual(model.config._attn_implementation, requested_attn_implementation)
|
||||||
|
for module in model.modules():
|
||||||
|
if "Attention" in module.__class__.__name__:
|
||||||
|
self.assertEqual(
|
||||||
|
module.__class__.__name__, mistral_attention_classes[requested_attn_implementation]
|
||||||
|
)
|
||||||
|
|
||||||
def test_torch_dtype_byte_sizes(self):
|
def test_torch_dtype_byte_sizes(self):
|
||||||
torch_dtypes_and_bytes = [
|
torch_dtypes_and_bytes = [
|
||||||
(torch.double, 8),
|
(torch.double, 8),
|
||||||
|
|||||||
Reference in New Issue
Block a user