From 7e5d46ded433605a906fcab6be43ac85307cca9b Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Mon, 5 Aug 2024 16:33:19 +0100 Subject: [PATCH] Respect the config's attn_implementation if set (#32383) * Respect the config's attn if set * Update test - can override in from_config * Fix --- src/transformers/modeling_utils.py | 10 +++++- tests/utils/test_modeling_utils.py | 54 ++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 651d207282..3fd364d702 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1454,7 +1454,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix dtype_orig = cls._set_default_torch_dtype(torch_dtype) config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. - config._attn_implementation = kwargs.pop("attn_implementation", None) + + if config._attn_implementation_internal is not None: + # In this case, the config has been created with the attn_implementation set by the user, which we + # should respect. + attn_implementation = config._attn_implementation_internal + else: + attn_implementation = None + + config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) config = cls._autoset_attn_implementation( config, use_flash_attention_2=use_flash_attention_2, diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 720731f39c..71c72f9212 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -574,6 +574,60 @@ class ModelUtilsTest(TestCasePlus): module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] ) + def test_model_from_config_attn_implementation(self): + # test that the model can be instantiated with attn_implementation of either + # 1. config created with explicit attn_implementatation and from_config + # 2. explicit from_config's attn_implementation argument with a config argument + # 3. config created with explicit attn_implementatation and from_config overriding with explicit attn_implementation argument + attn_implementation_available = ["eager"] + if is_torch_sdpa_available(): + attn_implementation_available.append("sdpa") + + if is_flash_attn_2_available(): + attn_implementation_available.append("flash_attention_2") + + mistral_attention_classes = { + "eager": "MistralAttention", + "sdpa": "MistralSdpaAttention", + "flash_attention_2": "MistralFlashAttention2", + } + for requested_attn_implementation in attn_implementation_available: + config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) + # Ensure the config was set correctly + self.assertEqual(config._attn_implementation, requested_attn_implementation) + self.assertEqual(config._attn_implementation_internal, requested_attn_implementation) + model = AutoModelForCausalLM.from_config(config) + self.assertEqual(model.config._attn_implementation, requested_attn_implementation) + for module in model.modules(): + if "Attention" in module.__class__.__name__: + self.assertEqual( + module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] + ) + + config = AutoConfig.from_pretrained(TINY_MISTRAL) + # When the config is not set, the default is "eager" + self.assertEqual(config._attn_implementation, "eager") + self.assertEqual(config._attn_implementation_internal, None) + model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) + self.assertEqual(model.config._attn_implementation, requested_attn_implementation) + for module in model.modules(): + if "Attention" in module.__class__.__name__: + self.assertEqual( + module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] + ) + + # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument + config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz") + self.assertEqual(config._attn_implementation, "foo-bar-baz") + self.assertEqual(config._attn_implementation_internal, "foo-bar-baz") + model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) + self.assertEqual(model.config._attn_implementation, requested_attn_implementation) + for module in model.modules(): + if "Attention" in module.__class__.__name__: + self.assertEqual( + module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] + ) + def test_torch_dtype_byte_sizes(self): torch_dtypes_and_bytes = [ (torch.double, 8),