From 55090585619d7ab880d9a7d7c8b327a746f7cc40 Mon Sep 17 00:00:00 2001 From: Gustavo de Rosa Date: Thu, 11 Jan 2024 11:58:02 -0300 Subject: [PATCH] [Phi] Extend implementation to use GQA/MQA. (#28163) * chore(phi): Updates configuration_phi with missing keys. * chore(phi): Adds first draft of combined modeling_phi. * fix(phi): Fixes according to latest review. * fix(phi): Removes pad_vocab_size_multiple to prevent inconsistencies. * fix(phi): Fixes unit and integration tests. * fix(phi): Ensures that everything works with microsoft/phi-1 for first integration. * fix(phi): Fixes output of docstring generation. * fix(phi): Fixes according to latest review. * fix(phi): Fixes according to latest review. * fix(tests): Re-enables Phi-1.5 test. * fix(phi): Fixes attention overflow on PhiAttention (for Phi-2). * fix(phi): Improves how queries and keys are upcast. * fix(phi): Small updates on latest changes. --- .../models/phi/configuration_phi.py | 25 +++- src/transformers/models/phi/modeling_phi.py | 140 +++++++++--------- tests/models/phi/test_modeling_phi.py | 18 +-- 3 files changed, 101 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/phi/configuration_phi.py b/src/transformers/models/phi/configuration_phi.py index 5025ef798f..1b495cc8e2 100644 --- a/src/transformers/models/phi/configuration_phi.py +++ b/src/transformers/models/phi/configuration_phi.py @@ -23,8 +23,9 @@ from ...utils import logging logger = logging.get_logger(__name__) PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "susnato/phi-1_dev": "https://huggingface.co/susnato/phi-1_dev/resolve/main/config.json", - "susnato/phi-1_5_dev": "https://huggingface.co/susnato/phi-1_5_dev/resolve/main/config.json", + "microsoft/phi-1": "https://huggingface.co/microsoft/phi-1/resolve/main/config.json", + "microsoft/phi-1_5": "https://huggingface.co/microsoft/phi-1_5/resolve/main/config.json", + "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json", } @@ -33,7 +34,7 @@ class PhiConfig(PretrainedConfig): This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate an Phi model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Phi - [susnato/phi-1_dev](https://huggingface.co/susnato/phi-1_dev). + [microsoft/phi-1](https://huggingface.co/microsoft/phi-1). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -50,6 +51,14 @@ class PhiConfig(PretrainedConfig): Number of hidden layers in the Transformer decoder. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. resid_pdrop (`float`, *optional*, defaults to 0.0): Dropout probability for mlp outputs. embd_pdrop (`int`, *optional*, defaults to 0.0): @@ -83,7 +92,7 @@ class PhiConfig(PretrainedConfig): partial_rotary_factor (`float`, *optional*, defaults to 0.5): Percentage of the query and keys which will have rotary embedding. qk_layernorm (`bool`, *optional*, defaults to `False`): - Whether or not to normalize the Queries and Keys after projecting the hidden states + Whether or not to normalize the Queries and Keys after projecting the hidden states. bos_token_id (`int`, *optional*, defaults to 1): Denotes beginning of sequences token id. eos_token_id (`int`, *optional*, defaults to 2): @@ -95,7 +104,7 @@ class PhiConfig(PretrainedConfig): >>> from transformers import PhiModel, PhiConfig >>> # Initializing a Phi-1 style configuration - >>> configuration = PhiConfig.from_pretrained("susnato/phi-1_dev") + >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1") >>> # Initializing a model from the configuration >>> model = PhiModel(configuration) @@ -114,6 +123,7 @@ class PhiConfig(PretrainedConfig): intermediate_size=8192, num_hidden_layers=24, num_attention_heads=32, + num_key_value_heads=None, resid_pdrop=0.0, embd_pdrop=0.0, attention_dropout=0.0, @@ -136,6 +146,11 @@ class PhiConfig(PretrainedConfig): self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads self.resid_pdrop = resid_pdrop self.embd_pdrop = embd_pdrop self.attention_dropout = attention_dropout diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index be568c62c7..f4d227e67c 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -54,12 +54,13 @@ if is_flash_attn_2_available(): logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "susnato/phi-1_dev" +_CHECKPOINT_FOR_DOC = "microsoft/phi-1" _CONFIG_FOR_DOC = "PhiConfig" PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "susnato/phi-1_dev", - "susnato/phi-1_5_dev", + "microsoft/phi-1", + "microsoft/phi-1_5", + "microsoft/phi-2", # See all Phi models at https://huggingface.co/models?filter=phi ] @@ -214,7 +215,19 @@ class PhiMLP(nn.Module): return hidden_states -# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonAttention with Persimmon->Phi,persimmon->phi +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + class PhiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -229,9 +242,12 @@ class PhiAttention(nn.Module): "when creating this class." ) + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.partial_rotary_factor = config.partial_rotary_factor @@ -242,10 +258,13 @@ class PhiAttention(nn.Module): f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) - self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) - self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) - self.qk_layernorm = config.qk_layernorm + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + + self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True @@ -253,7 +272,7 @@ class PhiAttention(nn.Module): self.k_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - self.attention_dropout = nn.Dropout(config.attention_dropout) + self._init_rope() def _init_rope(self): @@ -283,23 +302,6 @@ class PhiAttention(nn.Module): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads - def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory - storage as `fused_qkv` - - Args: - fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim] - - Returns: - query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim] - value: [batch_size, seq_length, num_heads, head_dim] - """ - batch_size, seq_length, three_times_hidden_size = fused_qkv.shape - fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) - return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] - def forward( self, hidden_states: torch.Tensor, @@ -311,20 +313,17 @@ class PhiAttention(nn.Module): ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - # [batch_size, seq_length, 3 x hidden_size] - fused_qkv = self.query_key_value(hidden_states) - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_states, key_states, value_states) = self._split_heads(fused_qkv) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) if self.qk_layernorm: query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] - query_states = query_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -354,11 +353,16 @@ class PhiAttention(nn.Module): key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - # Specific to RoPE models with partial rotation cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow + attn_weights = torch.matmul( + query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -374,8 +378,8 @@ class PhiAttention(nn.Module): attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype) - attn_weights = self.attention_dropout(attn_weights) + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) @@ -398,9 +402,9 @@ class PhiAttention(nn.Module): class PhiFlashAttention2(PhiAttention): """ - Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays untouched. - The only required change would be on the forward pass where it needs to correctly call the public API of flash - attention and deal with padding tokens in case the input contains any of them. + Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. """ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ @@ -415,11 +419,12 @@ class PhiFlashAttention2(PhiAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # PhiFlashAttention2 attention does not support output_attentions @@ -427,20 +432,20 @@ class PhiFlashAttention2(PhiAttention): bsz, q_len, _ = hidden_states.size() - # [batch_size, seq_length, 3 x hidden_size] - fused_qkv = self.query_key_value(hidden_states) - - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_states, key_states, value_states) = self._split_heads(fused_qkv) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) if self.qk_layernorm: query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim] - query_states = query_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -467,15 +472,13 @@ class PhiFlashAttention2(PhiAttention): cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - tgt_len = key_states.shape[2] + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - query_states = query_states.transpose(1, 2).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - - attn_dropout = self.config.attention_dropout if self.training else 0.0 + attn_dropout = self.attention_dropout if self.training else 0.0 # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -506,7 +509,7 @@ class PhiFlashAttention2(PhiAttention): query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=1.0 ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.dense(attn_output) if not output_attentions: @@ -708,6 +711,7 @@ class PhiPreTrainedModel(PreTrainedModel): config_class = PhiConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_cache_class = True @@ -745,7 +749,7 @@ PHI_INPUTS_DOCSTRING = r""" Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -852,13 +856,13 @@ class PhiModel(PhiPreTrainedModel): # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape + batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape + batch_size, seq_length = inputs_embeds.shape[:2] else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError("You have to specify either input_ids or inputs_embeds") past_key_values_length = 0 @@ -1020,8 +1024,8 @@ class PhiForCausalLM(PhiPreTrainedModel): ```python >>> from transformers import AutoTokenizer, PhiForCausalLM - >>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev") - >>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev") + >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1") >>> prompt = "This is an example script ." >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1029,7 +1033,7 @@ class PhiForCausalLM(PhiPreTrainedModel): >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'This is an example script .py file that uses the `os` module to create a new directory and write some text to it.\n\n``' + 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str' ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index f5fd51e98b..d69bbb32c1 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -365,18 +365,18 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, @require_bitsandbytes @pytest.mark.flash_attn_test @slow - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->susnato/phi-1_5_dev + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_2_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1 def test_flash_attn_2_generate_padding_right(self): """ Overwritting the common test as the test is flaky on tiny models """ model = PhiForCausalLM.from_pretrained( - "susnato/phi-1_5_dev", + "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, ) - tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev") + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1") texts = ["hi", "Hello this is a very long sentence"] @@ -389,7 +389,7 @@ class PhiModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, output_native = tokenizer.batch_decode(output_native) model = PhiForCausalLM.from_pretrained( - "susnato/phi-1_5_dev", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2" + "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_2" ) output_fa_2 = model.generate(**inputs, max_new_tokens=20, do_sample=False) @@ -408,7 +408,7 @@ class PhiIntegrationTest(unittest.TestCase): ) } - model = PhiForCausalLM.from_pretrained("susnato/phi-1_dev").to(torch_device) + model = PhiForCausalLM.from_pretrained("microsoft/phi-1").to(torch_device) model.eval() output = model(**input_ids).logits @@ -424,7 +424,7 @@ class PhiIntegrationTest(unittest.TestCase): ) } - model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev").to(torch_device) + model = PhiForCausalLM.from_pretrained("microsoft/phi-1_5").to(torch_device) model.eval() output = model(**input_ids).logits @@ -440,7 +440,7 @@ class PhiIntegrationTest(unittest.TestCase): ) } - model = PhiForCausalLM.from_pretrained("susnato/phi-2").to(torch_device) + model = PhiForCausalLM.from_pretrained("microsoft/phi-2").to(torch_device) model.eval() output = model(**input_ids).logits @@ -450,8 +450,8 @@ class PhiIntegrationTest(unittest.TestCase): self.assertTrue(torch.allclose(EXPECTED_OUTPUT, output[0, :2, :30], atol=1e-3, rtol=1e-3)) def test_phi_2_generation(self): - model = PhiForCausalLM.from_pretrained("susnato/phi-2") - tokenizer = AutoTokenizer.from_pretrained("susnato/phi-2") + model = PhiForCausalLM.from_pretrained("microsoft/phi-2") + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") inputs = tokenizer( "Can you help me write a formal email to a potential business partner proposing a joint venture?",