|
|
|
|
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.fx
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
import torch.utils.checkpoint
|
|
|
|
|
from torch import nn
|
|
|
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|
|
|
|
@@ -35,6 +36,8 @@ from ...utils import (
|
|
|
|
|
add_code_sample_docstrings,
|
|
|
|
|
add_start_docstrings,
|
|
|
|
|
add_start_docstrings_to_model_forward,
|
|
|
|
|
is_flash_attn_2_available,
|
|
|
|
|
is_flash_attn_greater_or_equal_2_10,
|
|
|
|
|
is_torch_fx_proxy,
|
|
|
|
|
logging,
|
|
|
|
|
)
|
|
|
|
|
@@ -42,6 +45,11 @@ from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
|
|
|
|
from .configuration_gptj import GPTJConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if is_flash_attn_2_available():
|
|
|
|
|
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
|
|
|
|
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
|
|
|
|
|
@@ -55,6 +63,19 @@ GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
|
|
|
|
def _get_unpad_data(attention_mask):
|
|
|
|
|
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
|
|
|
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
|
|
|
|
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
|
|
|
|
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
|
|
|
|
return (
|
|
|
|
|
indices,
|
|
|
|
|
cu_seqlens,
|
|
|
|
|
max_seqlen_in_batch,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
|
|
|
|
|
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
|
|
|
|
|
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
|
|
|
|
|
@@ -82,7 +103,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
|
|
|
|
|
class GPTJAttention(nn.Module):
|
|
|
|
|
def __init__(self, config):
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
self.config = config
|
|
|
|
|
max_positions = config.max_position_embeddings
|
|
|
|
|
self.register_buffer(
|
|
|
|
|
"bias",
|
|
|
|
|
@@ -96,6 +117,8 @@ class GPTJAttention(nn.Module):
|
|
|
|
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
|
|
|
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
|
|
|
|
|
|
|
|
|
self.is_causal = True
|
|
|
|
|
|
|
|
|
|
self.embed_dim = config.hidden_size
|
|
|
|
|
self.num_attention_heads = config.num_attention_heads
|
|
|
|
|
self.head_dim = self.embed_dim // self.num_attention_heads
|
|
|
|
|
@@ -269,6 +292,256 @@ class GPTJAttention(nn.Module):
|
|
|
|
|
return outputs # a, present, (attentions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPTJFlashAttention2(GPTJAttention):
|
|
|
|
|
"""
|
|
|
|
|
GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays
|
|
|
|
|
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
|
|
|
|
flash attention and deal with padding tokens in case the input contains any of them.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
|
|
|
|
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
|
|
|
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
|
|
|
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
hidden_states: torch.FloatTensor,
|
|
|
|
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
|
|
|
|
attention_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
head_mask: Optional[torch.FloatTensor] = None,
|
|
|
|
|
use_cache: Optional[bool] = False,
|
|
|
|
|
output_attentions: Optional[bool] = False,
|
|
|
|
|
) -> Union[
|
|
|
|
|
Tuple[torch.Tensor, Tuple[torch.Tensor]],
|
|
|
|
|
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
|
|
|
|
|
]:
|
|
|
|
|
query = self.q_proj(hidden_states)
|
|
|
|
|
key = self.k_proj(hidden_states)
|
|
|
|
|
value = self.v_proj(hidden_states)
|
|
|
|
|
|
|
|
|
|
query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
|
|
|
|
|
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
|
|
|
|
|
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
|
|
|
|
|
|
|
|
|
|
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
|
|
|
|
|
# The logic to conditionally copy to GPU could not be traced, so we do this
|
|
|
|
|
# every time in the torch.fx case
|
|
|
|
|
embed_positions = get_embed_positions(self.embed_positions, position_ids)
|
|
|
|
|
else:
|
|
|
|
|
embed_positions = self._get_embed_positions(position_ids)
|
|
|
|
|
|
|
|
|
|
repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
|
|
|
|
|
sincos = torch.gather(embed_positions, 1, repeated_position_ids)
|
|
|
|
|
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
|
|
|
|
|
|
|
|
|
|
if self.rotary_dim is not None:
|
|
|
|
|
k_rot = key[:, :, :, : self.rotary_dim]
|
|
|
|
|
k_pass = key[:, :, :, self.rotary_dim :]
|
|
|
|
|
|
|
|
|
|
q_rot = query[:, :, :, : self.rotary_dim]
|
|
|
|
|
q_pass = query[:, :, :, self.rotary_dim :]
|
|
|
|
|
|
|
|
|
|
k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
|
|
|
|
|
q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
|
|
|
|
|
|
|
|
|
|
key = torch.cat([k_rot, k_pass], dim=-1)
|
|
|
|
|
query = torch.cat([q_rot, q_pass], dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
key = apply_rotary_pos_emb(key, sin, cos)
|
|
|
|
|
query = apply_rotary_pos_emb(query, sin, cos)
|
|
|
|
|
|
|
|
|
|
# tanspose to have the desired shape
|
|
|
|
|
# before transpose: batch_size x seq_length x num_attention_heads x head_dim
|
|
|
|
|
# after transpose: batch_size x num_attention_heads x seq_length x head_dim
|
|
|
|
|
key = key.permute(0, 2, 1, 3)
|
|
|
|
|
query = query.permute(0, 2, 1, 3)
|
|
|
|
|
# value: batch_size x num_attention_heads x seq_length x head_dim
|
|
|
|
|
|
|
|
|
|
if layer_past is not None:
|
|
|
|
|
past_key = layer_past[0]
|
|
|
|
|
past_value = layer_past[1]
|
|
|
|
|
key = torch.cat((past_key, key), dim=-2)
|
|
|
|
|
value = torch.cat((past_value, value), dim=-2)
|
|
|
|
|
|
|
|
|
|
if use_cache is True:
|
|
|
|
|
# Note that this cast is quite ugly, but is not implemented before ROPE as the original codebase keeps the key in float32 all along the computation.
|
|
|
|
|
# Reference: https://github.com/kingoflolz/mesh-transformer-jax/blob/f8315e3003033b23f21d78361b288953064e0e76/mesh_transformer/layers.py#L128
|
|
|
|
|
present = (key.to(hidden_states.dtype), value)
|
|
|
|
|
else:
|
|
|
|
|
present = None
|
|
|
|
|
|
|
|
|
|
# The Flash attention requires the input to have the shape
|
|
|
|
|
# batch_size x seq_length x head_dim x hidden_dim
|
|
|
|
|
# therefore we need to keep the original shape for query and key, and reshape value
|
|
|
|
|
# to have the correct shape.
|
|
|
|
|
key = key.permute(0, 2, 1, 3).contiguous()
|
|
|
|
|
query = query.permute(0, 2, 1, 3).contiguous()
|
|
|
|
|
value = value.permute(0, 2, 1, 3).contiguous()
|
|
|
|
|
|
|
|
|
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
|
|
|
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
|
|
|
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
|
|
|
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
|
|
|
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
|
|
|
|
|
|
|
|
|
input_dtype = query.dtype
|
|
|
|
|
if input_dtype == torch.float32:
|
|
|
|
|
if torch.is_autocast_enabled():
|
|
|
|
|
target_dtype = torch.get_autocast_gpu_dtype()
|
|
|
|
|
# Handle the case where the model is quantized
|
|
|
|
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
|
|
|
|
target_dtype = self.config._pre_quantization_dtype
|
|
|
|
|
else:
|
|
|
|
|
target_dtype = self.q_proj.weight.dtype
|
|
|
|
|
|
|
|
|
|
logger.warning_once(
|
|
|
|
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
|
|
|
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
|
|
|
|
f" {target_dtype}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
query = query.to(target_dtype)
|
|
|
|
|
key = key.to(target_dtype)
|
|
|
|
|
value = value.to(target_dtype)
|
|
|
|
|
|
|
|
|
|
attention_dropout = self.config.attn_pdrop if self.training else 0.0 # attn_pdrop in gptj
|
|
|
|
|
|
|
|
|
|
query_length = query.shape[1]
|
|
|
|
|
|
|
|
|
|
# Compute attention
|
|
|
|
|
attn_weights = self._flash_attention_forward(
|
|
|
|
|
query,
|
|
|
|
|
key,
|
|
|
|
|
value,
|
|
|
|
|
attention_mask,
|
|
|
|
|
query_length,
|
|
|
|
|
dropout=attention_dropout,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Reshape outputs
|
|
|
|
|
attn_output = attn_weights.reshape(
|
|
|
|
|
attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3]
|
|
|
|
|
)
|
|
|
|
|
attn_output = self.out_proj(attn_output)
|
|
|
|
|
attn_output = self.resid_dropout(attn_output)
|
|
|
|
|
|
|
|
|
|
outputs = (attn_output, present)
|
|
|
|
|
if output_attentions:
|
|
|
|
|
outputs += (attn_weights,)
|
|
|
|
|
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
|
|
|
|
|
def _flash_attention_forward(
|
|
|
|
|
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
|
|
|
|
first unpad the input, then computes the attention scores and pad the final attention scores.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
query_states (`torch.Tensor`):
|
|
|
|
|
Input query states to be passed to Flash Attention API
|
|
|
|
|
key_states (`torch.Tensor`):
|
|
|
|
|
Input key states to be passed to Flash Attention API
|
|
|
|
|
value_states (`torch.Tensor`):
|
|
|
|
|
Input value states to be passed to Flash Attention API
|
|
|
|
|
attention_mask (`torch.Tensor`):
|
|
|
|
|
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
|
|
|
|
position of padding tokens and 1 for the position of non-padding tokens.
|
|
|
|
|
dropout (`int`, *optional*):
|
|
|
|
|
Attention dropout
|
|
|
|
|
softmax_scale (`float`, *optional*):
|
|
|
|
|
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
|
|
|
|
"""
|
|
|
|
|
if not self._flash_attn_uses_top_left_mask:
|
|
|
|
|
causal = self.is_causal
|
|
|
|
|
else:
|
|
|
|
|
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
|
|
|
|
causal = self.is_causal and query_length != 1
|
|
|
|
|
|
|
|
|
|
# Contains at least one padding token in the sequence
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
batch_size = query_states.shape[0]
|
|
|
|
|
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
|
|
|
|
query_states, key_states, value_states, attention_mask, query_length
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
|
|
|
|
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
|
|
|
|
|
|
|
|
|
attn_output_unpad = flash_attn_varlen_func(
|
|
|
|
|
query_states,
|
|
|
|
|
key_states,
|
|
|
|
|
value_states,
|
|
|
|
|
cu_seqlens_q=cu_seqlens_q,
|
|
|
|
|
cu_seqlens_k=cu_seqlens_k,
|
|
|
|
|
max_seqlen_q=max_seqlen_in_batch_q,
|
|
|
|
|
max_seqlen_k=max_seqlen_in_batch_k,
|
|
|
|
|
dropout_p=dropout,
|
|
|
|
|
softmax_scale=softmax_scale,
|
|
|
|
|
causal=causal,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
|
|
|
|
else:
|
|
|
|
|
attn_output = flash_attn_func(
|
|
|
|
|
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return attn_output
|
|
|
|
|
|
|
|
|
|
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input with num_heads->num_attention_heads
|
|
|
|
|
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
|
|
|
|
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
|
|
|
|
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
|
|
|
|
|
|
|
|
|
key_layer = index_first_axis(
|
|
|
|
|
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
|
|
|
|
)
|
|
|
|
|
value_layer = index_first_axis(
|
|
|
|
|
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
|
|
|
|
|
)
|
|
|
|
|
if query_length == kv_seq_len:
|
|
|
|
|
query_layer = index_first_axis(
|
|
|
|
|
query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads, head_dim), indices_k
|
|
|
|
|
)
|
|
|
|
|
cu_seqlens_q = cu_seqlens_k
|
|
|
|
|
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
|
|
indices_q = indices_k
|
|
|
|
|
elif query_length == 1:
|
|
|
|
|
max_seqlen_in_batch_q = 1
|
|
|
|
|
cu_seqlens_q = torch.arange(
|
|
|
|
|
batch_size + 1, dtype=torch.int32, device=query_layer.device
|
|
|
|
|
) # There is a memcpy here, that is very bad.
|
|
|
|
|
indices_q = cu_seqlens_q[:-1]
|
|
|
|
|
query_layer = query_layer.squeeze(1)
|
|
|
|
|
else:
|
|
|
|
|
# The -q_len: slice assumes left padding.
|
|
|
|
|
attention_mask = attention_mask[:, -query_length:]
|
|
|
|
|
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
query_layer,
|
|
|
|
|
key_layer,
|
|
|
|
|
value_layer,
|
|
|
|
|
indices_q,
|
|
|
|
|
(cu_seqlens_q, cu_seqlens_k),
|
|
|
|
|
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
GPTJ_ATTENTION_CLASSES = {
|
|
|
|
|
"eager": GPTJAttention,
|
|
|
|
|
"flash_attention_2": GPTJFlashAttention2,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPTJMLP(nn.Module):
|
|
|
|
|
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
|
|
|
|
|
super().__init__()
|
|
|
|
|
@@ -293,7 +566,7 @@ class GPTJBlock(nn.Module):
|
|
|
|
|
super().__init__()
|
|
|
|
|
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
|
|
|
|
|
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
|
|
|
|
self.attn = GPTJAttention(config)
|
|
|
|
|
self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config)
|
|
|
|
|
self.mlp = GPTJMLP(inner_dim, config)
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
@@ -343,6 +616,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
|
|
|
|
|
supports_gradient_checkpointing = True
|
|
|
|
|
_no_split_modules = ["GPTJBlock"]
|
|
|
|
|
_skip_keys_device_placement = "past_key_values"
|
|
|
|
|
_supports_flash_attn_2 = True
|
|
|
|
|
|
|
|
|
|
def __init__(self, *inputs, **kwargs):
|
|
|
|
|
super().__init__(*inputs, **kwargs)
|
|
|
|
|
@@ -496,6 +770,8 @@ class GPTJModel(GPTJPreTrainedModel):
|
|
|
|
|
# Initialize weights and apply final processing
|
|
|
|
|
self.post_init()
|
|
|
|
|
|
|
|
|
|
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
|
|
|
|
|
|
|
|
|
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
|
|
|
|
def parallelize(self, device_map=None):
|
|
|
|
|
warnings.warn(
|
|
|
|
|
@@ -600,6 +876,7 @@ class GPTJModel(GPTJPreTrainedModel):
|
|
|
|
|
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
|
|
|
|
position_ids = position_ids.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
if not self._use_flash_attention_2:
|
|
|
|
|
# Attention mask.
|
|
|
|
|
if attention_mask is not None:
|
|
|
|
|
if batch_size <= 0:
|
|
|
|
|
|