Generate: support for left-padding on GPTNeoX and Llama (#22382)
This commit is contained in:
@@ -100,12 +100,13 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: torch.FloatTensor,
|
||||||
attention_mask,
|
attention_mask: torch.FloatTensor,
|
||||||
head_mask=None,
|
position_ids: torch.LongTensor,
|
||||||
layer_past=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache=False,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions=False,
|
use_cache: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
has_layer_past = layer_past is not None
|
has_layer_past = layer_past is not None
|
||||||
|
|
||||||
@@ -132,12 +133,10 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute token offset for rotary embeddings (when decoding)
|
# Compute token offset for rotary embeddings (when decoding)
|
||||||
seq_len = key.shape[-2]
|
seq_len = key.shape[-2]
|
||||||
offset = 0
|
|
||||||
if has_layer_past:
|
if has_layer_past:
|
||||||
offset = layer_past[0].shape[-2]
|
seq_len += layer_past[0].shape[-2]
|
||||||
seq_len += offset
|
|
||||||
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
||||||
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
|
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
|
||||||
query = torch.cat((query, query_pass), dim=-1)
|
query = torch.cat((query, query_pass), dim=-1)
|
||||||
key = torch.cat((key, key_pass), dim=-1)
|
key = torch.cat((key, key_pass), dim=-1)
|
||||||
|
|
||||||
@@ -275,9 +274,11 @@ def rotate_half(x):
|
|||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
cos = cos[..., offset : q.shape[-2] + offset, :]
|
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||||
sin = sin[..., offset : q.shape[-2] + offset, :]
|
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||||
|
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||||
|
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
@@ -308,16 +309,18 @@ class GPTNeoXLayer(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states: Optional[torch.FloatTensor],
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask=None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
use_cache=False,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
layer_past=None,
|
use_cache: Optional[bool] = False,
|
||||||
output_attentions=False,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
attention_layer_outputs = self.attention(
|
attention_layer_outputs = self.attention(
|
||||||
self.input_layernorm(hidden_states),
|
self.input_layernorm(hidden_states),
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -374,6 +377,11 @@ GPT_NEOX_INPUTS_DOCSTRING = r"""
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
[What are attention masks?](../glossary#attention-mask)
|
||||||
|
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
|
||||||
|
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
|
||||||
|
config.max_position_embeddings - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||||
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@@ -430,6 +438,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
@@ -467,7 +476,17 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
|
|
||||||
if past_key_values is None:
|
if past_key_values is None:
|
||||||
|
past_length = 0
|
||||||
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
# Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@@ -527,12 +546,14 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
|
|||||||
create_custom_forward(layer),
|
create_custom_forward(layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
outputs = layer(
|
outputs = layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -587,6 +608,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
@@ -640,6 +662,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
outputs = self.gpt_neox(
|
outputs = self.gpt_neox(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
@@ -672,20 +695,29 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
|
||||||
|
|
||||||
# cut decoder_input_ids if past is used
|
# cut decoder_input_ids if past is used
|
||||||
if past_key_values and past_key_values[0] is not None:
|
if past_key_values and past_key_values[0] is not None:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"position_ids": position_ids,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ class GPTJAttention(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[torch.FloatTensor],
|
hidden_states: torch.FloatTensor,
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
||||||
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
|
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
|
||||||
"""
|
"""
|
||||||
Make causal mask used for bi-directional self-attention.
|
Make causal mask used for bi-directional self-attention.
|
||||||
@@ -53,6 +54,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
|
|||||||
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||||
"""
|
"""
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||||
@@ -126,9 +128,11 @@ def rotate_half(x):
|
|||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
cos = cos[..., offset : q.shape[-2] + offset, :]
|
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||||
sin = sin[..., offset : q.shape[-2] + offset, :]
|
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||||
|
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||||
|
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
@@ -197,13 +201,12 @@ class LlamaAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
@@ -211,12 +214,10 @@ class LlamaAttention(nn.Module):
|
|||||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
offset = 0
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
offset = past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
kv_seq_len += offset
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||||
# [bsz, nh, t, hd]
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
@@ -283,9 +284,10 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -308,8 +310,9 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=past_key_value,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
@@ -406,7 +409,11 @@ LLAMA_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
|
||||||
|
config.max_position_embeddings - 1]`.
|
||||||
|
|
||||||
|
[What are position IDs?](../glossary#position-ids)
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||||
@@ -488,10 +495,12 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
return combined_attention_mask
|
return combined_attention_mask
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
@@ -499,49 +508,6 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
|
||||||
provide it.
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
|
||||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
|
||||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
|
||||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
|
||||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
||||||
(see `past_key_values`).
|
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
|
||||||
than the model's internal embedding lookup matrix.
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
||||||
for more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
"""
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -559,11 +525,23 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
seq_length_with_past = seq_length
|
seq_length_with_past = seq_length
|
||||||
past_key_values_length = 0
|
past_key_values_length = 0
|
||||||
|
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_key_values_length = past_key_values[0][0].shape[2]
|
past_key_values_length = past_key_values[0][0].shape[2]
|
||||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
# embed positions
|
# embed positions
|
||||||
@@ -608,12 +586,14 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
create_custom_forward(decoder_layer),
|
create_custom_forward(decoder_layer),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -674,11 +654,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
def get_decoder(self):
|
def get_decoder(self):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
@@ -689,52 +671,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
|
||||||
provide it.
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
|
||||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
|
||||||
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
|
|
||||||
tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
|
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
|
||||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
|
||||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
|
||||||
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
|
||||||
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
|
||||||
than the model's internal embedding lookup matrix.
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
||||||
(see `past_key_values`).
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
|
||||||
for more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -765,6 +705,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -807,6 +748,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
if past_key_values:
|
if past_key_values:
|
||||||
input_ids = input_ids[:, -1:]
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
if inputs_embeds is not None and past_key_values is None:
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
@@ -815,6 +764,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
{
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": kwargs.get("use_cache"),
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
@@ -868,6 +818,7 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
@@ -886,8 +837,9 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
transformer_outputs = self.model(
|
transformer_outputs = self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
past_key_values=past_key_values,
|
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
@require_torch
|
@require_torch
|
||||||
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_lm_generate_codegen(self):
|
def test_lm_generate_gptneox(self):
|
||||||
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
|
||||||
for checkpointing in [True, False]:
|
for checkpointing in [True, False]:
|
||||||
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
|
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
|
||||||
|
|||||||
Reference in New Issue
Block a user