From 2f12e408225b1ebceb0d2f701ce419d46678dc31 Mon Sep 17 00:00:00 2001 From: Jonathan Tow <41410219+jon-tow@users.noreply.github.com> Date: Mon, 8 Apr 2024 17:51:58 -0400 Subject: [PATCH] [`StableLm`] Add QK normalization and Parallel Residual Support (#29745) * init: add StableLm 2 support * add integration test for parallel residual and qk layernorm * update(modeling): match qk norm naming for consistency with phi/persimmon * fix(tests): run fwd/bwd on random init test model to jitter norm weights off identity * `use_parallel_residual`: add copy pointer to `GPTNeoXLayer.forward` * refactor: rename head states var in `StableLmLayerNormPerHead` * tests: update test model and add generate check --- .../models/stablelm/configuration_stablelm.py | 9 +++ .../models/stablelm/modeling_stablelm.py | 64 ++++++++++++++++--- .../models/stablelm/test_modeling_stablelm.py | 34 ++++++++++ 3 files changed, 97 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/stablelm/configuration_stablelm.py b/src/transformers/models/stablelm/configuration_stablelm.py index f1e3ab4517..beb4af4d84 100644 --- a/src/transformers/models/stablelm/configuration_stablelm.py +++ b/src/transformers/models/stablelm/configuration_stablelm.py @@ -83,6 +83,11 @@ class StableLmConfig(PretrainedConfig): is an experimental feature, subject to breaking API changes in future versions. use_qkv_bias (`bool`, *optional*, defaults to `False`): Whether or not the model should use bias for qkv layers. + qk_layernorm (`bool`, *optional*, defaults to `False`): + Whether or not to normalize, per head, the Queries and Keys after projecting the hidden states. + use_parallel_residual (`bool`, *optional*, defaults to `False`): + Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training + speedup at large scales. hidden_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio after applying the MLP to the hidden states. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -123,6 +128,8 @@ class StableLmConfig(PretrainedConfig): rope_theta=10_000, rope_scaling=None, use_qkv_bias=False, + qk_layernorm=False, + use_parallel_residual=False, hidden_dropout=0.0, attention_dropout=0.0, partial_rotary_factor=0.25, @@ -146,6 +153,8 @@ class StableLmConfig(PretrainedConfig): self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.use_qkv_bias = use_qkv_bias + self.qk_layernorm = qk_layernorm + self.use_parallel_residual = use_parallel_residual self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout self.partial_rotary_factor = partial_rotary_factor diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 76aca7bae9..3262f2cd3c 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -203,6 +203,21 @@ class StableLmMLP(nn.Module): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) +class StableLmLayerNormPerHead(nn.Module): + def __init__(self, dim, num_heads, eps=1e-5, bias=False): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)]) + + def forward(self, hidden_states: torch.Tensor): + # Split along the num_heads axis to get per-head inputs + # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads + states_per_heads = torch.split(hidden_states, 1, dim=1) + # Normalize and merge the heads back together + return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1) + + # Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -250,6 +265,13 @@ class StableLmAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps) + self.k_layernorm = StableLmLayerNormPerHead( + self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps + ) + self.attention_dropout = nn.Dropout(config.attention_dropout) self._init_rope() @@ -300,6 +322,10 @@ class StableLmAttention(nn.Module): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -409,6 +435,10 @@ class StableLmSdpaAttention(StableLmAttention): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -513,6 +543,10 @@ class StableLmFlashAttention2(StableLmAttention): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: @@ -678,11 +712,14 @@ ATTENTION_CLASSES = { class StableLmDecoderLayer(nn.Module): def __init__(self, config: StableLmConfig, layer_idx: int): super().__init__() + self.use_parallel_residual = config.use_parallel_residual self.hidden_size = config.hidden_size self.self_attn = ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) self.mlp = StableLmMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = None + if not self.use_parallel_residual: + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) def forward( @@ -719,7 +756,7 @@ class StableLmDecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + self_attn_output, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -727,15 +764,22 @@ class StableLmDecoderLayer(nn.Module): output_attentions=output_attentions, use_cache=use_cache, ) - hidden_states = residual + hidden_states - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = hidden_states + residual + # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward + if self.use_parallel_residual: + # x = x + attn(ln1(x)) + mlp(ln1(x)) + # Fully Connected + mlp_output = self.mlp(hidden_states) + mlp_output = self.dropout(mlp_output) + hidden_states = residual + self_attn_output + mlp_output + else: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + residual = residual + self_attn_output + # Fully Connected + mlp_output = self.mlp(self.post_attention_layernorm(residual)) + mlp_output = self.dropout(mlp_output) + hidden_states = residual + mlp_output outputs = (hidden_states,) diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index 0f01fa8fca..f5e74ead9b 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -483,6 +483,40 @@ class StableLmModelIntegrationTest(unittest.TestCase): EXPECTED_TEXT_COMPLETION = """My favorite food has always been pizza, but lately I’ve been craving something different. I’ve been trying to eat healthier and I’ve""" self.assertEqual(text, EXPECTED_TEXT_COMPLETION) + @slow + def test_model_tiny_random_stablelm_2_logits(self): + # Check parallel residual and qk layernorm forward pass + input_ids = {"input_ids": torch.tensor([[510, 8588, 310, 1900, 9386]], dtype=torch.long, device=torch_device)} + + model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2").to(torch_device) + model.eval() + + output = model(**input_ids).logits + + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[-2.7196, -3.6099, -2.6877, -3.1973, -3.9344]]).to(torch_device) + self.assertTrue(torch.allclose(output.mean(dim=-1), EXPECTED_MEAN, atol=1e-4, rtol=1e-4)) + + # Expected logits sliced from [0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([2.8364, 5.3811, 5.1659, 7.5485, 4.3219, 6.3315, 1.3967, 6.9147, 3.9679, 6.4786, 5.9176, 3.3067, 5.2917, 0.1485, 3.9630, 7.9947,10.6727, 9.6757, 8.8772, 8.3527, 7.8445, 6.6025, 5.5786, 7.0985,6.1369, 3.4259, 1.9397, 4.6157, 4.8105, 3.1768]).to(torch_device) # fmt: skip + self.assertTrue(torch.allclose(output[0, 0, :30], EXPECTED_SLICE, atol=1e-4, rtol=1e-4)) + + @slow + def test_model_tiny_random_stablelm_2_generation(self): + # Check parallel residual and qk layernorm generation + tokenizer = AutoTokenizer.from_pretrained("stabilityai/tiny-random-stablelm-2") + model = StableLmForCausalLM.from_pretrained("stabilityai/tiny-random-stablelm-2") + input_ids = tokenizer.encode( + "My favorite ride at the amusement park", + return_tensors="pt", + ) + + outputs = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(outputs[0], skip_special_tokens=True) + + EXPECTED_TEXT_COMPLETION = """My favorite ride at the amusement park is the 2000-mile roller coaster. It's a thrilling ride filled with roller coast""" + self.assertEqual(text, EXPECTED_TEXT_COMPLETION) + @require_bitsandbytes @slow @require_flash_attn