From 64845307b362f4dfff1a783d7bce0f3407e92c34 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:25:15 +0200 Subject: [PATCH] Remove unnecessary unsqueeze - squeeze in rotary positional embedding (#26162) * remove unnecessary unsqueeze-squeeze in llama * correct other models * fix * revert gpt_neox_japanese * fix copie * fix test --- .../open_llama/modeling_open_llama.py | 23 ++++++++---------- .../models/falcon/modeling_falcon.py | 16 ++++++------- .../models/gpt_neox/modeling_gpt_neox.py | 23 ++++++++---------- .../modeling_gpt_neox_japanese.py | 8 +++---- .../models/idefics/modeling_idefics.py | 15 +++++------- .../models/llama/modeling_llama.py | 24 +++++++++---------- .../models/mistral/modeling_mistral.py | 15 +++++------- .../models/persimmon/modeling_persimmon.py | 23 ++++++++---------- tests/models/mistral/test_modeling_mistral.py | 3 ++- 9 files changed, 67 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index ee017792bb..c975aa4087 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -118,8 +118,8 @@ class OpenLlamaRotaryEmbedding(nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -127,8 +127,8 @@ class OpenLlamaRotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -148,8 +148,8 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama @@ -175,8 +175,8 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def rotate_half(x): @@ -188,11 +188,8 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 14d1e3c0ad..a82ff813ad 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -115,8 +115,8 @@ class FalconRotaryEmbedding(nn.Module): if dtype in [torch.float16, torch.bfloat16]: emb = emb.float() - self.cos_cached = emb.cos()[None, :, :] - self.sin_cached = emb.sin()[None, :, :] + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() self.cos_cached = self.cos_cached.type(dtype) self.sin_cached = self.sin_cached.type(dtype) @@ -133,8 +133,8 @@ class FalconRotaryEmbedding(nn.Module): self.sin_cached = self.sin_cached.to(device) # Gather cos, sin at the designated position ids - cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim] - sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim] + cos = self.cos_cached[position_ids] # [bs, seq_len, dim] + sin = self.sin_cached[position_ids] # [bs, seq_len, dim] return cos, sin def forward(self, query, key, past_key_values_length, position_ids): @@ -181,8 +181,8 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): if dtype in [torch.float16, torch.bfloat16]: emb = emb.float() - self.cos_cached = emb.cos()[None, :, :] - self.sin_cached = emb.sin()[None, :, :] + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() self.cos_cached = self.cos_cached.type(dtype) self.sin_cached = self.sin_cached.type(dtype) @@ -215,8 +215,8 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding): if dtype in [torch.float16, torch.bfloat16]: emb = emb.float() - self.cos_cached = emb.cos()[None, :, :] - self.sin_cached = emb.sin()[None, :, :] + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() self.cos_cached = self.cos_cached.type(dtype) self.sin_cached = self.sin_cached.type(dtype) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index d20de65985..b4aa415445 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -309,8 +309,8 @@ class GPTNeoXRotaryEmbedding(nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -318,8 +318,8 @@ class GPTNeoXRotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -339,8 +339,8 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->GPTNeoX @@ -366,8 +366,8 @@ class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def rotate_half(x): @@ -379,11 +379,8 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 5e94096a6d..98753edeb5 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -261,8 +261,8 @@ class RotaryEmbedding(nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -270,8 +270,8 @@ class RotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 28f6ad9fb5..316f365613 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -523,8 +523,8 @@ class IdeficsEmbedding(torch.nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -532,8 +532,8 @@ class IdeficsEmbedding(torch.nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -546,11 +546,8 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8a335eed3c..55753d5f75 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -138,8 +138,8 @@ class LlamaRotaryEmbedding(nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -147,8 +147,8 @@ class LlamaRotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -167,8 +167,8 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): @@ -193,8 +193,8 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def rotate_half(x): @@ -204,12 +204,10 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 8edd0fc60e..a55f16a23d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -149,8 +149,8 @@ class MistralRotaryEmbedding(nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -158,8 +158,8 @@ class MistralRotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -173,11 +173,8 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 4d3b9af2e4..c09657c065 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -94,8 +94,8 @@ class PersimmonRotaryEmbedding(nn.Module): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] @@ -103,8 +103,8 @@ class PersimmonRotaryEmbedding(nn.Module): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), ) @@ -124,8 +124,8 @@ class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon @@ -151,8 +151,8 @@ class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding): freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -165,11 +165,8 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] + sin = sin[position_ids].unsqueeze(1) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 403f2cc734..df1143d251 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -430,7 +430,8 @@ class MistralIntegrationTest(unittest.TestCase): def test_model_7b_logits(self): input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", device_map="auto") - out = model(torch.tensor([input_ids])).logits + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + out = model(input_ids).logits.cpu() # Expected mean on dim = -1 EXPECTED_MEAN = torch.tensor([[-2.5548, -2.5737, -3.0600, -2.5906, -2.8478, -2.8118, -2.9325, -2.7694]]) torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)