Llama: fix batched generation (#29109)
This commit is contained in:
@@ -101,11 +101,34 @@ class LlamaRotaryEmbedding(nn.Module):
|
|||||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sin_cached(self):
|
||||||
|
logger.warning_once(
|
||||||
|
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
|
||||||
|
"the forward method of RoPE from now on instead."
|
||||||
|
)
|
||||||
|
return self._sin_cached
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cos_cached(self):
|
||||||
|
logger.warning_once(
|
||||||
|
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
|
||||||
|
"the forward method of RoPE from now on instead."
|
||||||
|
)
|
||||||
|
return self._cos_cached
|
||||||
|
|
||||||
def forward(self, x, position_ids, seq_len=None):
|
def forward(self, x, position_ids, seq_len=None):
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
|
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||||
|
position_ids_expanded = position_ids[:, None, :].float()
|
||||||
|
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
|
cos = emb.cos().to(dtype=x.dtype)
|
||||||
|
sin = emb.sin().to(dtype=x.dtype)
|
||||||
|
# backwards compatibility
|
||||||
|
self._cos_cached = cos
|
||||||
|
self._sin_cached = sin
|
||||||
|
return cos, sin
|
||||||
|
|
||||||
|
|
||||||
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||||
@@ -181,6 +204,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
|||||||
Returns:
|
Returns:
|
||||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||||
"""
|
"""
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
@@ -1033,6 +1058,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
batch_size, seq_length = input_tensor.shape[:2]
|
batch_size, seq_length = input_tensor.shape[:2]
|
||||||
dtype = input_tensor.dtype
|
dtype = input_tensor.dtype
|
||||||
|
device = input_tensor.device
|
||||||
|
|
||||||
# support going beyond cached `max_position_embedding`
|
# support going beyond cached `max_position_embedding`
|
||||||
if seq_length > self.causal_mask.shape[-1]:
|
if seq_length > self.causal_mask.shape[-1]:
|
||||||
@@ -1048,8 +1074,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
(self.config.max_position_embeddings, self.config.max_position_embeddings),
|
(self.config.max_position_embeddings, self.config.max_position_embeddings),
|
||||||
fill_value=torch.finfo(dtype).min,
|
fill_value=torch.finfo(dtype).min,
|
||||||
)
|
)
|
||||||
causal_mask = torch.triu(mask, diagonal=1).to(dtype)
|
causal_mask = torch.triu(mask, diagonal=1)
|
||||||
|
|
||||||
|
causal_mask = causal_mask.to(dtype=dtype, device=device)
|
||||||
if attention_mask is not None and attention_mask.dim() == 2:
|
if attention_mask is not None and attention_mask.dim() == 2:
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||||
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
|
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
|
||||||
EXPECTED_GENERATION = [
|
EXPECTED_GENERATION = [
|
||||||
"The best color is the one that complements the subject you are photograph",
|
"The best color is the one that complements the skin tone of the",
|
||||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -333,18 +333,18 @@ class CacheIntegrationTest(unittest.TestCase):
|
|||||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||||
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
|
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
|
||||||
EXPECTED_GENERATION = [
|
EXPECTED_GENERATION = [
|
||||||
"The best color is\n\n\n\n\n\n\n\n\n\n",
|
"The best color isЋ the one that complements the skin tone of",
|
||||||
"We should not undermind the issues at hand, but address them head on.\nI think",
|
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||||
]
|
]
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
"NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
"NousResearch/Llama-2-7b-chat-hf",
|
"NousResearch/Llama-2-7b-chat-hf",
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
attn_implementation=attn_implementation,
|
attn_implementation=attn_implementation,
|
||||||
).to("cuda:1")
|
).to(torch_device)
|
||||||
inputs = tokenizer(
|
inputs = tokenizer(
|
||||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||||
).to(model.device)
|
).to(model.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user