From 21c912e79c8ee62034177bd43c9c628be9b46e2a Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sat, 20 Apr 2024 00:45:53 +0800 Subject: [PATCH] Fix config + attn_implementation in AutoModelForCausalLM.from_pretrained (#30299) * Update modeling_utils.py * Update test_modeling_utils.py * Update test_modeling_utils.py * Update test_modeling_utils.py --- src/transformers/modeling_utils.py | 2 +- tests/test_modeling_utils.py | 38 ++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e4fee8a526..f9ebd42a17 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3146,7 +3146,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix config = copy.deepcopy(config) kwarg_attn_imp = kwargs.pop("attn_implementation", None) - if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp: + if kwarg_attn_imp is not None: config._attn_implementation = kwarg_attn_imp model_kwargs = kwargs diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index b6c1e99737..37ae919a44 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -427,6 +427,44 @@ class ModelUtilsTest(TestCasePlus): model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto") self.assertEqual(model.dtype, torch.float32) + def test_model_from_pretrained_attn_implementation(self): + # test that the model can be instantiated with attn_implementation of either + # 1. explicit from_pretrained's attn_implementation argument + # 2. explicit from_pretrained's attn_implementation argument with a config argument + attn_implementation_available = ["eager"] + if is_torch_sdpa_available(): + attn_implementation_available.append("sdpa") + + if is_flash_attn_2_available(): + attn_implementation_available.append("flash_attention_2") + + mistral_attention_classes = { + "eager": "MistralAttention", + "sdpa": "MistralSdpaAttention", + "flash_attention_2": "MistralFlashAttention2", + } + for requested_attn_implementation in attn_implementation_available: + model = AutoModelForCausalLM.from_pretrained( + TINY_MISTRAL, attn_implementation=requested_attn_implementation + ) + self.assertEqual(model.config._attn_implementation, requested_attn_implementation) + for module in model.modules(): + if "Attention" in module.__class__.__name__: + self.assertEqual( + module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] + ) + + config = AutoConfig.from_pretrained(TINY_MISTRAL) + model = AutoModelForCausalLM.from_pretrained( + TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation + ) + self.assertEqual(model.config._attn_implementation, requested_attn_implementation) + for module in model.modules(): + if "Attention" in module.__class__.__name__: + self.assertEqual( + module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] + ) + def test_no_super_init_config_and_model(self): config = NoSuperInitConfig(attribute=32) model = NoSuperInitModel(config)