Falcon: batched generation (#26137)
This commit is contained in:
@@ -67,6 +67,7 @@ def rotate_half(x):
|
|||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities
|
||||||
class FalconRotaryEmbedding(nn.Module):
|
class FalconRotaryEmbedding(nn.Module):
|
||||||
"""Implementation of RotaryEmbedding from GPT-NeoX.
|
"""Implementation of RotaryEmbedding from GPT-NeoX.
|
||||||
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
|
This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
|
||||||
@@ -99,19 +100,40 @@ class FalconRotaryEmbedding(nn.Module):
|
|||||||
self.cos_cached = self.cos_cached.type(dtype)
|
self.cos_cached = self.cos_cached.type(dtype)
|
||||||
self.sin_cached = self.sin_cached.type(dtype)
|
self.sin_cached = self.sin_cached.type(dtype)
|
||||||
|
|
||||||
def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
|
def cos_sin(
|
||||||
|
self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16
|
||||||
|
) -> torch.Tensor:
|
||||||
total_length = seq_len + past_key_values_length
|
total_length = seq_len + past_key_values_length
|
||||||
if total_length > self.seq_len_cached:
|
if total_length > self.seq_len_cached:
|
||||||
self._set_cos_sin_cache(total_length, device, dtype)
|
self._set_cos_sin_cache(total_length, device, dtype)
|
||||||
return (
|
# Gather cos, sin at the designated position ids
|
||||||
self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||||
self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
|
sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim]
|
||||||
)
|
return cos, sin
|
||||||
|
|
||||||
def forward(self, query, key, past_key_values_length=0):
|
def forward(self, query, key, past_key_values_length, position_ids):
|
||||||
batch, seq_len, head_dim = query.shape
|
_, seq_len, _ = query.shape
|
||||||
cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
|
cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype)
|
||||||
return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
|
# Query and key shapes are [bs * num_heads, seq_len, dim] while cos/sin are [bs, seq_len, dim], so manual expansion might be needed. Ifs and elses are used to
|
||||||
|
# avoid unnecessary repeat_interleave operations.
|
||||||
|
query_expansion_factor = int(query.shape[0] / cos.shape[0])
|
||||||
|
if query_expansion_factor > 1:
|
||||||
|
query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0)
|
||||||
|
query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0)
|
||||||
|
else:
|
||||||
|
query_cos, query_sin = cos, sin
|
||||||
|
|
||||||
|
key_expansion_factor = int(key.shape[0] / cos.shape[0])
|
||||||
|
if key_expansion_factor > 1:
|
||||||
|
if key_expansion_factor != query_expansion_factor:
|
||||||
|
key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0)
|
||||||
|
key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0)
|
||||||
|
else:
|
||||||
|
key_cos, key_sin = query_cos, query_sin
|
||||||
|
else:
|
||||||
|
key_cos, key_sin = cos, sin
|
||||||
|
|
||||||
|
return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin)
|
||||||
|
|
||||||
|
|
||||||
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
|
||||||
@@ -270,7 +292,7 @@ class FalconAttention(nn.Module):
|
|||||||
f" {self.num_heads})."
|
f" {self.num_heads})."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k)
|
self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k)
|
||||||
|
|
||||||
# Layer-wise attention scaling
|
# Layer-wise attention scaling
|
||||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||||
@@ -378,6 +400,7 @@ class FalconAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
alibi: Optional[torch.Tensor],
|
alibi: Optional[torch.Tensor],
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
@@ -399,7 +422,7 @@ class FalconAttention(nn.Module):
|
|||||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
|
||||||
|
|
||||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids)
|
||||||
|
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past
|
past_key, past_value = layer_past
|
||||||
@@ -415,7 +438,8 @@ class FalconAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
present = None
|
present = None
|
||||||
|
|
||||||
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
float_min = torch.finfo(query_layer.dtype).min
|
||||||
|
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype)
|
||||||
|
|
||||||
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
||||||
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||||
@@ -536,6 +560,7 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
alibi: Optional[torch.Tensor],
|
alibi: Optional[torch.Tensor],
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
@@ -554,6 +579,7 @@ class FalconDecoderLayer(nn.Module):
|
|||||||
attention_layernorm_out,
|
attention_layernorm_out,
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
alibi=alibi,
|
alibi=alibi,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -632,6 +658,11 @@ FALCON_INPUTS_DOCSTRING = r"""
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of the position of each input sequence token in the position embeddings. Selected in the range `[0,
|
||||||
|
config.n_positions - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@@ -836,6 +867,7 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.LongTensor] = None,
|
head_mask: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -892,6 +924,14 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||||
else:
|
else:
|
||||||
alibi = None
|
alibi = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
causal_mask = self._prepare_attn_mask(
|
causal_mask = self._prepare_attn_mask(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -922,6 +962,7 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
alibi,
|
alibi,
|
||||||
causal_mask,
|
causal_mask,
|
||||||
|
position_ids,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -929,6 +970,7 @@ class FalconModel(FalconPreTrainedModel):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=causal_mask,
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -988,13 +1030,23 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
input_ids: torch.LongTensor,
|
input_ids: torch.LongTensor,
|
||||||
past_key_values: Optional[torch.Tensor] = None,
|
past_key_values: Optional[torch.Tensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
# Note: versions of Falcon with alibi do not use position_ids; they are only used with RoPE.
|
||||||
|
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
|
"position_ids": position_ids,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": kwargs.get("use_cache"),
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
@@ -1011,6 +1063,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.Tensor] = None,
|
head_mask: Optional[torch.Tensor] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
@@ -1032,6 +1085,7 @@ class FalconForCausalLM(FalconPreTrainedModel):
|
|||||||
input_ids,
|
input_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
|||||||
@@ -19,8 +19,16 @@ import unittest
|
|||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed
|
from transformers import (
|
||||||
from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
FalconConfig,
|
||||||
|
is_torch_available,
|
||||||
|
set_seed,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device
|
||||||
from transformers.utils import logging as transformers_logging
|
from transformers.utils import logging as transformers_logging
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase):
|
|||||||
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
|
||||||
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0)
|
||||||
|
|
||||||
|
@require_bitsandbytes
|
||||||
|
@slow
|
||||||
|
def test_batched_generation(self):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left")
|
||||||
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"tiiuae/falcon-7b",
|
||||||
|
device_map="auto",
|
||||||
|
load_in_4bit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
test_text = "A sequence: 1, 2" # should generate the rest of the sequence
|
||||||
|
|
||||||
|
unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0")
|
||||||
|
unpadded_inputs.pop("token_type_ids")
|
||||||
|
unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20)
|
||||||
|
unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True)
|
||||||
|
|
||||||
|
dummy_text = "This is a longer text " * 2 # forces left-padding on `test_text`
|
||||||
|
padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0")
|
||||||
|
padded_inputs.pop("token_type_ids")
|
||||||
|
padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20)
|
||||||
|
padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True)
|
||||||
|
|
||||||
|
expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, "
|
||||||
|
self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists
|
||||||
|
self.assertEqual(unpadded_gen_text[0], expected_output)
|
||||||
|
self.assertEqual(padded_gen_text[0], expected_output)
|
||||||
|
|
||||||
|
|
||||||
# TODO Lysandre: Remove this in version v4.34
|
# TODO Lysandre: Remove this in version v4.34
|
||||||
class FalconOverrideTest(unittest.TestCase):
|
class FalconOverrideTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user