From a796f7eea6c86b54671a6f522cebbe41f630bb62 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 13 Sep 2023 17:00:52 +0100 Subject: [PATCH] Falcon: batched generation (#26137) --- .../models/falcon/modeling_falcon.py | 78 ++++++++++++++++--- tests/models/falcon/test_modeling_falcon.py | 41 +++++++++- 2 files changed, 105 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 7677ef64dd..c541fab0a2 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -67,6 +67,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) +# TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities class FalconRotaryEmbedding(nn.Module): """Implementation of RotaryEmbedding from GPT-NeoX. This implementation is designed to operate on queries and keys that are compatible with `[batch_size, @@ -99,19 +100,40 @@ class FalconRotaryEmbedding(nn.Module): self.cos_cached = self.cos_cached.type(dtype) self.sin_cached = self.sin_cached.type(dtype) - def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor: + def cos_sin( + self, seq_len: int, past_key_values_length: int, position_ids: torch.Tensor, device="cpu", dtype=torch.bfloat16 + ) -> torch.Tensor: total_length = seq_len + past_key_values_length if total_length > self.seq_len_cached: self._set_cos_sin_cache(total_length, device, dtype) - return ( - self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length], - self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length], - ) + # Gather cos, sin at the designated position ids + cos = self.cos_cached.squeeze(0)[position_ids] # [bs, seq_len, dim] + sin = self.sin_cached.squeeze(0)[position_ids] # [bs, seq_len, dim] + return cos, sin - def forward(self, query, key, past_key_values_length=0): - batch, seq_len, head_dim = query.shape - cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype) - return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) + def forward(self, query, key, past_key_values_length, position_ids): + _, seq_len, _ = query.shape + cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype) + # Query and key's shapes are [bs * num_heads, seq_len, dim], might need manual expansion. Ifs and elses used to + # avoid unnecessary repeat_interleave operations. + query_expansion_factor = int(query.shape[0] / cos.shape[0]) + if query_expansion_factor > 1: + query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0) + query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0) + else: + query_cos, query_sin = cos, sin + + key_expansion_factor = int(key.shape[0] / cos.shape[0]) + if key_expansion_factor > 1: + if key_expansion_factor != query_expansion_factor: + key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0) + key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0) + else: + key_cos, key_sin = query_cos, query_sin + else: + key_cos, key_sin = cos, sin + + return (query * query_cos) + (rotate_half(query) * query_sin), (key * key_cos) + (rotate_half(key) * key_sin) class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): @@ -270,7 +292,7 @@ class FalconAttention(nn.Module): f" {self.num_heads})." ) - self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t: (q, k) + self.maybe_rotary = self._init_rope() if config.rotary else lambda q, k, t, p: (q, k) # Layer-wise attention scaling self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) @@ -378,6 +400,7 @@ class FalconAttention(nn.Module): hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, @@ -399,7 +422,7 @@ class FalconAttention(nn.Module): value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim) past_kv_length = 0 if layer_past is None else layer_past[0].shape[1] - query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length) + query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length, position_ids) if layer_past is not None: past_key, past_value = layer_past @@ -415,7 +438,8 @@ class FalconAttention(nn.Module): else: present = None - attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype) + float_min = torch.finfo(query_layer.dtype).min + attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype) query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) @@ -536,6 +560,7 @@ class FalconDecoderLayer(nn.Module): hidden_states: torch.Tensor, alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, @@ -554,6 +579,7 @@ class FalconDecoderLayer(nn.Module): attention_layernorm_out, layer_past=layer_past, attention_mask=attention_mask, + position_ids=position_ids, alibi=alibi, head_mask=head_mask, use_cache=use_cache, @@ -632,6 +658,11 @@ FALCON_INPUTS_DOCSTRING = r""" - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: @@ -836,6 +867,7 @@ class FalconModel(FalconPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -892,6 +924,14 @@ class FalconModel(FalconPreTrainedModel): alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) else: alibi = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() causal_mask = self._prepare_attn_mask( attention_mask, @@ -922,6 +962,7 @@ class FalconModel(FalconPreTrainedModel): hidden_states, alibi, causal_mask, + position_ids, head_mask[i], ) else: @@ -929,6 +970,7 @@ class FalconModel(FalconPreTrainedModel): hidden_states, layer_past=layer_past, attention_mask=causal_mask, + position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, @@ -988,13 +1030,23 @@ class FalconForCausalLM(FalconPreTrainedModel): input_ids: torch.LongTensor, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: if past_key_values is not None: input_ids = input_ids[:, -1:] + # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. + if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + return { "input_ids": input_ids, + "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, @@ -1011,6 +1063,7 @@ class FalconForCausalLM(FalconPreTrainedModel): input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, @@ -1032,6 +1085,7 @@ class FalconForCausalLM(FalconPreTrainedModel): input_ids, past_key_values=past_key_values, attention_mask=attention_mask, + position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index c5f7f7a8f9..81f1b1511e 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -19,8 +19,16 @@ import unittest from parameterized import parameterized -from transformers import AutoConfig, AutoModel, AutoTokenizer, FalconConfig, is_torch_available, set_seed -from transformers.testing_utils import CaptureLogger, require_torch, slow, tooslow, torch_device +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoTokenizer, + FalconConfig, + is_torch_available, + set_seed, +) +from transformers.testing_utils import CaptureLogger, require_bitsandbytes, require_torch, slow, tooslow, torch_device from transformers.utils import logging as transformers_logging from ...generation.test_utils import GenerationTesterMixin @@ -502,6 +510,35 @@ class FalconLanguageGenerationTest(unittest.TestCase): outputs_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True) self.assertTrue((outputs_cache - outputs_no_cache).sum().item() == 0) + @require_bitsandbytes + @slow + def test_batched_generation(self): + tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", padding_side="left") + tokenizer.pad_token = tokenizer.eos_token + model = AutoModelForCausalLM.from_pretrained( + "tiiuae/falcon-7b", + device_map="auto", + load_in_4bit=True, + ) + + test_text = "A sequence: 1, 2" # should generate the rest of the sequence + + unpadded_inputs = tokenizer([test_text], return_tensors="pt").to("cuda:0") + unpadded_inputs.pop("token_type_ids") + unpadded_gen_out = model.generate(**unpadded_inputs, max_new_tokens=20) + unpadded_gen_text = tokenizer.batch_decode(unpadded_gen_out, skip_special_tokens=True) + + dummy_text = "This is a longer text " * 2 # forces left-padding on `test_text` + padded_inputs = tokenizer([test_text, dummy_text], return_tensors="pt", padding=True).to("cuda:0") + padded_inputs.pop("token_type_ids") + padded_gen_out = model.generate(**padded_inputs, max_new_tokens=20) + padded_gen_text = tokenizer.batch_decode(padded_gen_out, skip_special_tokens=True) + + expected_output = "A sequence: 1, 2, 3, 4, 5, 6, 7, 8, " + self.assertLess(unpadded_inputs.input_ids.shape[-1], padded_inputs.input_ids.shape[-1]) # left-padding exists + self.assertEqual(unpadded_gen_text[0], expected_output) + self.assertEqual(padded_gen_text[0], expected_output) + # TODO Lysandre: Remove this in version v4.34 class FalconOverrideTest(unittest.TestCase):