From aef12349b690541285a8ff6565561e8c45af2afe Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 19 May 2025 06:35:23 -0700 Subject: [PATCH] Make HF implementation match original OLMo 2 models for lower precisions (#38131) * Make HF implementation match OLMo models for lower precisions * Add test of 1B logits in bfloat16 * Run make fixup --- src/transformers/models/olmo/modeling_olmo.py | 58 +++++++-------- src/transformers/models/olmo/modular_olmo.py | 44 ++++++++++- .../models/olmo2/modeling_olmo2.py | 74 +++++++++---------- .../models/olmo2/modular_olmo2.py | 16 +++- tests/models/olmo2/test_modeling_olmo2.py | 12 +++ 5 files changed, 135 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 410dd25a96..72abff85dd 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -70,33 +70,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -135,6 +108,34 @@ def eager_attention_forward( return attn_output, attn_weights +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + q_type, k_type = q.dtype, k.dtype + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q_type), k_embed.to(k_type) + + class OlmoAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -302,8 +303,7 @@ class OlmoRotaryEmbedding(nn.Module): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos, sin @auto_docstring diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index a2af171557..02ff85ac10 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -16,8 +16,8 @@ from ..llama.modeling_llama import ( LlamaModel, LlamaPreTrainedModel, LlamaRotaryEmbedding, - apply_rotary_pos_emb, eager_attention_forward, + rotate_half, ) from .configuration_olmo import OlmoConfig @@ -47,6 +47,34 @@ class OlmoMLP(LlamaMLP): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + q_type, k_type = q.dtype, k.dtype + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q_type), k_embed.to(k_type) + + class OlmoAttention(LlamaAttention): def forward( self, @@ -115,8 +143,20 @@ class OlmoDecoderLayer(LlamaDecoderLayer): self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) +# This is identical to LlamaRotaryEmbedding except the output cos and sin are returned +# as float32 rather than the input type. class OlmoRotaryEmbedding(LlamaRotaryEmbedding): - pass + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos, sin class OlmoPreTrainedModel(LlamaPreTrainedModel): diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index f87befb459..58cad69374 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -48,46 +48,12 @@ class Olmo2RMSNorm(nn.Module): hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + return (self.weight * hidden_states).to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -126,6 +92,41 @@ def eager_attention_forward( return attn_output, attn_weights +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + q_type, k_type = q.dtype, k.dtype + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q_type), k_embed.to(k_type) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + class Olmo2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -306,8 +307,7 @@ class Olmo2RotaryEmbedding(nn.Module): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + return cos, sin @auto_docstring diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index c43263e954..a652709862 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -165,13 +165,27 @@ class Olmo2Config(OlmoConfig): del self.clip_qkv +# OLMo2 RMS norm is identical to Llama RMS norm except: +# - Weight and hidden states are multiplied before converting back to the input dtype, rather than after. class Olmo2RMSNorm(LlamaRMSNorm): - pass + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) ALL_LAYERNORM_LAYERS.append(Olmo2RMSNorm) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + # Olmo2 attention is identical to OLMo attention except: # - Norm is applied to attention queries and keys. # - No qkv clipping. diff --git a/tests/models/olmo2/test_modeling_olmo2.py b/tests/models/olmo2/test_modeling_olmo2.py index 455c63551b..dcef3148d2 100644 --- a/tests/models/olmo2/test_modeling_olmo2.py +++ b/tests/models/olmo2/test_modeling_olmo2.py @@ -232,6 +232,18 @@ class Olmo2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_torch class Olmo2IntegrationTest(unittest.TestCase): + @slow + def test_model_1b_logits_bfloat16(self): + input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] + model = Olmo2ForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B").to(torch.bfloat16) + out = model(torch.tensor(input_ids)).logits.float() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[-5.7094, -6.5548, -3.2527, -2.7847, -5.5092, -4.5223, -4.8427, -4.6867]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([2.4531, -5.7188, -5.1562, -4.8750, -6.7812, -4.0625, -4.4375, -4.5938, -7.5938, -5.0938, -3.9375, -3.6875, -5.0938, -3.1875, -5.6875, 0.2266, 1.2578, 1.1016, 0.8945, 0.4785, 0.2256, -0.3613, -0.4258, 0.1377, -0.1104, -7.1875, -5.2188, -6.8125, -0.9062, -2.9062]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-2, atol=1e-2) + @slow def test_model_7b_logits(self): input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]