Phi3: fix attn for sliding window (#33586)
* fix phi3 attn fir sliding window * fix tests * address most comment * style * update after rebase * add more models * fix tests
This commit is contained in:
committed by
GitHub
parent
a265600c60
commit
adea67541a
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import BaseModelOutputWithPast
|
from ...modeling_outputs import BaseModelOutputWithPast
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
@@ -1027,7 +1027,7 @@ class MimiTransformerModel(nn.Module):
|
|||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.gemma.modeling_gemma.GemmaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -1046,21 +1046,30 @@ class MimiTransformerModel(nn.Module):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1077,6 +1086,8 @@ class MimiTransformerModel(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1088,13 +1099,12 @@ class MimiTransformerModel(nn.Module):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1103,6 +1113,8 @@ class MimiTransformerModel(nn.Module):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: MimiConfig,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1110,13 +1122,11 @@ class MimiTransformerModel(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1125,6 +1135,10 @@ class MimiTransformerModel(nn.Module):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`MimiConfig`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1134,19 +1148,27 @@ class MimiTransformerModel(nn.Module):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -856,7 +856,7 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
use_cache: bool,
|
use_cache: bool,
|
||||||
output_attentions: bool,
|
output_attentions: bool,
|
||||||
):
|
):
|
||||||
if self._attn_implementation == "flash_attention_2":
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
if attention_mask is not None and use_cache:
|
if attention_mask is not None and use_cache:
|
||||||
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
|
||||||
if is_padding_right:
|
if is_padding_right:
|
||||||
@@ -872,12 +872,11 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
|
|
||||||
# cache_position must be valid here no matter which cache we use
|
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if (
|
if (
|
||||||
self.config._attn_implementation == "sdpa"
|
self.config._attn_implementation == "sdpa"
|
||||||
and not (using_static_cache or using_sliding_window_cache)
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
@@ -906,29 +905,17 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
else past_seen_tokens + sequence_length + 1
|
else past_seen_tokens + sequence_length + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
causal_mask = attention_mask
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
else:
|
attention_mask,
|
||||||
causal_mask = torch.full(
|
sequence_length=sequence_length,
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
target_length=target_length,
|
||||||
)
|
dtype=dtype,
|
||||||
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
device=device,
|
||||||
if self.config.sliding_window is not None:
|
cache_position=cache_position,
|
||||||
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
|
batch_size=input_tensor.shape[0],
|
||||||
exclude_mask.bitwise_or_(
|
config=self.config,
|
||||||
torch.arange(target_length, device=device)
|
past_key_values=past_key_values,
|
||||||
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
|
|
||||||
)
|
|
||||||
causal_mask *= exclude_mask
|
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
|
||||||
if attention_mask is not None:
|
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
|
||||||
if attention_mask.dim() == 2:
|
|
||||||
mask_length = attention_mask.shape[-1]
|
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
|
||||||
padding_mask = padding_mask == 0
|
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
|
||||||
padding_mask, min_dtype
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -944,6 +931,73 @@ class MistralModel(MistralPreTrainedModel):
|
|||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
config: MistralConfig,
|
||||||
|
past_key_values: Cache,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
config (`MistralConfig`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
if config.sliding_window is not None:
|
||||||
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
|
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
|
||||||
_tied_weights_keys = ["lm_head.weight"]
|
_tied_weights_keys = ["lm_head.weight"]
|
||||||
@@ -1074,6 +1128,78 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||||
|
else:
|
||||||
|
# `contiguous()` needed for compilation use cases
|
||||||
|
model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
|
||||||
|
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_cache_shape(),
|
||||||
|
dtype=self.lm_head.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_to_keep is not None:
|
||||||
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -1068,7 +1068,7 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
router_logits=all_router_logits,
|
router_logits=all_router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -1087,21 +1087,30 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1118,6 +1127,8 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1129,13 +1140,12 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mixtral
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1144,6 +1154,8 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: MixtralConfig,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1151,13 +1163,11 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1166,6 +1176,10 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`MixtralConfig`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1175,19 +1189,27 @@ class MixtralModel(MixtralPreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1344,6 +1366,76 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
|
|||||||
router_logits=outputs.router_logits,
|
router_logits=outputs.router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
output_router_logits=False,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
|
else:
|
||||||
|
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
|
||||||
|
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_cache_shape(),
|
||||||
|
dtype=self.lm_head.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_to_keep is not None:
|
||||||
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"output_router_logits": output_router_logits,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -1036,7 +1036,6 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -1055,21 +1054,30 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1086,6 +1094,8 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1097,13 +1107,12 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1112,6 +1121,8 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: Phi3Config,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1119,13 +1130,11 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1134,6 +1143,10 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`Phi3Config`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1143,19 +1156,27 @@ class Phi3Model(Phi3PreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1372,6 +1393,8 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_logits_to_keep is not None:
|
if num_logits_to_keep is not None:
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
|
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -1203,7 +1203,7 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
router_logits=all_router_logits,
|
router_logits=all_router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -1222,21 +1222,30 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1253,6 +1262,8 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1264,13 +1275,12 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phimoe
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1279,6 +1289,8 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: PhimoeConfig,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1286,13 +1298,11 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1301,6 +1311,10 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`PhimoeConfig`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1310,19 +1324,27 @@ class PhimoeModel(PhimoePreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1561,6 +1583,8 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if num_logits_to_keep is not None:
|
if num_logits_to_keep is not None:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -954,7 +954,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -973,21 +973,30 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1004,6 +1013,8 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1015,13 +1026,12 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1030,6 +1040,8 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: Qwen2Config,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1037,13 +1049,11 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1052,6 +1062,10 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`Qwen2Config`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1061,19 +1075,27 @@ class Qwen2Model(Qwen2PreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1206,6 +1228,79 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||||
|
else:
|
||||||
|
# `contiguous()` needed for compilation use cases
|
||||||
|
model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
|
||||||
|
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_cache_shape(),
|
||||||
|
dtype=self.lm_head.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_to_keep is not None:
|
||||||
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -1135,7 +1135,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
router_logits=all_router_logits,
|
router_logits=all_router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -1154,21 +1154,30 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1185,6 +1194,8 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1196,13 +1207,12 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2Moe
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1211,6 +1221,8 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: Qwen2MoeConfig,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1218,13 +1230,11 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1233,6 +1243,10 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`Qwen2MoeConfig`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1242,19 +1256,27 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1410,6 +1432,79 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin):
|
|||||||
router_logits=outputs.router_logits,
|
router_logits=outputs.router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||||
|
else:
|
||||||
|
# `contiguous()` needed for compilation use cases
|
||||||
|
model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
|
||||||
|
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_cache_shape(),
|
||||||
|
dtype=self.lm_head.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_to_keep is not None:
|
||||||
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ import torch.utils.checkpoint
|
|||||||
from torch.nn import CrossEntropyLoss, LayerNorm
|
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, StaticCache
|
from ...cache_utils import Cache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import (
|
from ...modeling_attn_mask_utils import (
|
||||||
AttentionMaskConverter,
|
AttentionMaskConverter,
|
||||||
@@ -1217,7 +1217,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -1236,21 +1236,30 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -1267,6 +1276,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1278,13 +1289,12 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1293,6 +1303,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: Qwen2VLConfig,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1300,13 +1312,11 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1315,6 +1325,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`Qwen2VLConfig`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1324,19 +1338,27 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1820,6 +1842,8 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_inputs.update(
|
model_inputs.update(
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
@@ -928,7 +928,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
attentions=all_self_attns,
|
attentions=all_self_attns,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
# Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
|
||||||
def _update_causal_mask(
|
def _update_causal_mask(
|
||||||
self,
|
self,
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
@@ -947,21 +947,30 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
# to infer the attention mask.
|
# to infer the attention mask.
|
||||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
|
||||||
|
|
||||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and not (using_static_cache or using_sliding_window_cache)
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
attention_mask,
|
attention_mask,
|
||||||
inputs_embeds=input_tensor,
|
inputs_embeds=input_tensor,
|
||||||
past_key_values_length=past_seen_tokens,
|
past_key_values_length=past_seen_tokens,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
is_training=self.training,
|
is_training=self.training,
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
dtype, device = input_tensor.dtype, input_tensor.device
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
sequence_length = input_tensor.shape[1]
|
sequence_length = input_tensor.shape[1]
|
||||||
if using_static_cache:
|
# SlidingWindowCache or StaticCache
|
||||||
|
if using_sliding_window_cache or using_static_cache:
|
||||||
target_length = past_key_values.get_max_cache_shape()
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
# DynamicCache or no cache
|
||||||
else:
|
else:
|
||||||
target_length = (
|
target_length = (
|
||||||
attention_mask.shape[-1]
|
attention_mask.shape[-1]
|
||||||
@@ -978,6 +987,8 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
device=device,
|
device=device,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
batch_size=input_tensor.shape[0],
|
batch_size=input_tensor.shape[0],
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -989,13 +1000,12 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
min_dtype = torch.finfo(dtype).min
|
|
||||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
|
# Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Starcoder2
|
||||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: torch.Tensor,
|
attention_mask: torch.Tensor,
|
||||||
sequence_length: int,
|
sequence_length: int,
|
||||||
@@ -1004,6 +1014,8 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
cache_position: torch.Tensor,
|
cache_position: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
|
config: Starcoder2Config,
|
||||||
|
past_key_values: Cache,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
@@ -1011,13 +1023,11 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
attention_mask (`torch.Tensor`):
|
attention_mask (`torch.Tensor`):
|
||||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
|
||||||
`(batch_size, 1, query_length, key_value_length)`.
|
|
||||||
sequence_length (`int`):
|
sequence_length (`int`):
|
||||||
The sequence length being processed.
|
The sequence length being processed.
|
||||||
target_length (`int`):
|
target_length (`int`):
|
||||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
|
||||||
dtype (`torch.dtype`):
|
dtype (`torch.dtype`):
|
||||||
The dtype to use for the 4D attention mask.
|
The dtype to use for the 4D attention mask.
|
||||||
device (`torch.device`):
|
device (`torch.device`):
|
||||||
@@ -1026,6 +1036,10 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
Indices depicting the position of the input sequence tokens in the sequence.
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
batch_size (`torch.Tensor`):
|
batch_size (`torch.Tensor`):
|
||||||
Batch size.
|
Batch size.
|
||||||
|
config (`Starcoder2Config`):
|
||||||
|
The model's configuration class
|
||||||
|
past_key_values (`Cache`):
|
||||||
|
The cache class that is being used currently to generate
|
||||||
"""
|
"""
|
||||||
if attention_mask is not None and attention_mask.dim() == 4:
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
@@ -1035,19 +1049,27 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
|
|||||||
causal_mask = torch.full(
|
causal_mask = torch.full(
|
||||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
if sequence_length != 1:
|
diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
if config.sliding_window is not None:
|
||||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
# if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
|
||||||
|
# the check is needed to verify is current checkpoint was trained with sliding window or not
|
||||||
|
if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
|
||||||
|
sliding_attend_mask = torch.arange(target_length, device=device) <= (
|
||||||
|
cache_position.reshape(-1, 1) - config.sliding_window
|
||||||
|
)
|
||||||
|
diagonal_attend_mask |= sliding_attend_mask
|
||||||
|
causal_mask *= diagonal_attend_mask
|
||||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
if attention_mask.shape[-1] > target_length:
|
||||||
|
attention_mask = attention_mask[:, :target_length]
|
||||||
mask_length = attention_mask.shape[-1]
|
mask_length = attention_mask.shape[-1]
|
||||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
padding_mask = padding_mask == 0
|
padding_mask = padding_mask == 0
|
||||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
padding_mask, min_dtype
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1182,6 +1204,79 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation
|
||||||
|
def prepare_inputs_for_generation(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
attention_mask=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
cache_position=None,
|
||||||
|
position_ids=None,
|
||||||
|
use_cache=True,
|
||||||
|
num_logits_to_keep=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
|
||||||
|
# Exception 1: when passing input_embeds, input_ids may be missing entries
|
||||||
|
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
|
||||||
|
if past_key_values is not None:
|
||||||
|
if inputs_embeds is not None: # Exception 1
|
||||||
|
input_ids = input_ids[:, -cache_position.shape[0] :]
|
||||||
|
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
|
||||||
|
input_ids = input_ids[:, cache_position]
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||||
|
|
||||||
|
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
|
||||||
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and cache_position[0] == 0:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
|
||||||
|
else:
|
||||||
|
# `contiguous()` needed for compilation use cases
|
||||||
|
model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None}
|
||||||
|
|
||||||
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
|
||||||
|
if model_inputs["inputs_embeds"] is not None:
|
||||||
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
|
||||||
|
device = model_inputs["inputs_embeds"].device
|
||||||
|
else:
|
||||||
|
batch_size, sequence_length = model_inputs["input_ids"].shape
|
||||||
|
device = model_inputs["input_ids"].device
|
||||||
|
|
||||||
|
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=past_key_values.get_max_cache_shape(),
|
||||||
|
dtype=self.lm_head.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=batch_size,
|
||||||
|
config=self.config,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_to_keep is not None:
|
||||||
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -615,3 +615,93 @@ class Phi3IntegrationTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
def test_phi3_mini_4k_sliding_window(self):
|
||||||
|
"""
|
||||||
|
This tests that Phi3 doesn't deteriorate in quality for long context generations. Since Phi3 has
|
||||||
|
sliding window attention, the test is tailored so that (context + max_new_tokens > sliding_window).
|
||||||
|
See #33586 for more
|
||||||
|
"""
|
||||||
|
model = Phi3ForCausalLM.from_pretrained(
|
||||||
|
"microsoft/Phi-3-mini-4k-instruct", device_map=torch_device, torch_dtype=torch.bfloat16
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
||||||
|
|
||||||
|
input_text = """
|
||||||
|
<|user|>
|
||||||
|
Tell me about Paris, France.<|end|>
|
||||||
|
<|assistant|>
|
||||||
|
Paris, the capital city of France, is renowned for its rich history, iconic landmarks, and vibrant culture. Known as "The City of Light," Paris is situated in the north-central part of the country along the Seine River.
|
||||||
|
|
||||||
|
Here are some key aspects of Paris:
|
||||||
|
|
||||||
|
1. Landmarks: Paris is home to numerous famous landmarks, including the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and the Champs-Élysées. The Eiffel Tower, built in 1889, is an iconic symbol of Paris and attracts millions of tourists each year. The Louvre Museum, the world's largest art museum, houses thousands of works of art, including the Mona Lisa and the Venus de Milo.
|
||||||
|
|
||||||
|
2. History: Paris has a rich history dating back to the 3rd century BC, when it was founded by a Celtic tribe called the Parisii. Over the centuries, the city has been influenced by various cultures, including the Romans, the Franks, and the Normans. The French Revolution in the late 18th century marked a significant turning point in Paris's history, leading to the establishment of the modern French Republic.
|
||||||
|
|
||||||
|
3. Culture: Paris is a global center for art, fashion, gastronomy, and culture. The city is home to numerous museums, including the Centre Pompidou, Musée d'Orsay, and Musée Rodin. Paris is also known for its fashion industry, with many famous designers having their origins in the city. The city's cuisine is also highly regarded, with a focus on fresh ingredients, and a wide variety of dishes, including French classics like coq au vin, boeuf bourguignon, and crêpes.
|
||||||
|
|
||||||
|
4. Architecture: Parisian architecture is characterized by its diverse styles, ranging from Gothic and Romanesque to Art Nouveau and Art Deco. The city's famous Haussmannian buildings, designed by Baron Haussmann in the mid-19th century, are known for their uniform facades, wrought-iron balconies, and large windows.
|
||||||
|
|
||||||
|
5. Transportation: Paris has an extensive public transportation system, including the Paris Métro, RER (suburban trains), and buses. The city's iconic yellow taxis are also a popular mode of transportation.
|
||||||
|
|
||||||
|
6. Language: The official language of Paris is French, and the city's residents are known for their charm and politeness.
|
||||||
|
|
||||||
|
7. Festivals and Events: Paris hosts numerous festivals and events throughout the year, including the annual Bastille Day celebrations, the Paris Fashion Week, and the famous annual New Year's Eve fireworks on the Eiffel Tower.
|
||||||
|
|
||||||
|
8. Geography: Paris is located in the north-central part of France, with the Seine River running through the city. The city's geography is characterized by rolling hills and picturesque parks, such as the Bois de Boulogne and the Jardin des Tuileries.
|
||||||
|
|
||||||
|
9. Population: As of 2021, Paris has an estimated population of around 2.2 million residents, with the metropolitan area housing over 12 million people.
|
||||||
|
|
||||||
|
In summary, Paris is a city steeped in history, culture, and art, with a unique blend of architectural styles and a vibrant atmosphere that continues to captivate millions of visitors each year.<|end|>
|
||||||
|
<|user|>
|
||||||
|
Please give me a list of 5 architectural landmarks in Paris, France.<|end|>
|
||||||
|
<|assistant|>
|
||||||
|
1. Eiffel Tower: Designed by Gustave Eiffel and completed in 1889, the Eiffel Tower is an iconic symbol of Paris and France. Standing at 324 meters tall, it was the tallest man-made structure in the world until the completion of the Chrysler Building in New York in 1930. The Eiffel Tower is made of wrought iron and offers visitors stunning views of the city from its three levels.
|
||||||
|
|
||||||
|
2. Notre-Dame Cathedral: Located on the Île de la Cité, Notre-Dame Cathedral is a masterpiece of French Gothic architecture. Construction began in the 12th century and continued for over 200 years, with the cathedral's completion in the 14th century. The cathedral is famous for its intricate facade, stained-glass windows, and the iconic gargoyles and chimeras.
|
||||||
|
|
||||||
|
3. Louvre Museum: Originally built as a fortress in the 12th century, the Louvre Museum is now the world's largest art museum and a historic monument in Paris. The museum's most famous landmark is the iconic glass pyramid entrance, designed by architect I. M. Pei in the 1980s. The Louvre houses over 380,000 works of art, including the Mona Lisa and the Venus de Milo.
|
||||||
|
|
||||||
|
4. Sacré-Cœur Basilica: The Sacré-Cœur Basilica, also known as the Basilique du Sacré-Cœur, is a Roman Catholic church and minor basilica located at the summit of the butte Montmartre, the highest point in Paris. The basilica was designed by Paul Abadie and dedicated in 1914. Its white domes and lavender-colored travertine stone make it a distinctive landmark in the Paris skyline.
|
||||||
|
|
||||||
|
5. Arc de Triomphe: The Arc de Triomphe is a monumental structure located at the western end of the Champs-Élysées. Commissioned by Napoleon in 1806, the Arc was designed by Jean-François-Thérèse Chalgrin and completed in 1836. The monument honors those who fought and died for France during the French Revolutionary and Napoleonic Wars. The Arc features sculptural reliefs and inscriptions, and its façade is adorned with the names of 357 generals and 660 soldiers.
|
||||||
|
|
||||||
|
These five architectural landmarks showcase the diverse styles and historical periods of Paris, from Gothic to Neoclassical, and from the 19th to the 20th centuries. Each landmark has its unique features and contributes to the city's rich architectural heritage.<|end|>
|
||||||
|
<|user|>
|
||||||
|
Please give me a list of 10 famous items displayed in the Louvre Museum. Thanks!<|end|>
|
||||||
|
<|assistant|>
|
||||||
|
1. Mona Lisa: The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is arguably the most famous painting in the world. The portrait is known for its enigmatic smile and masterful use of sfumato, a technique that creates a soft, hazy effect.
|
||||||
|
|
||||||
|
2. Venus de Milo: This ancient Greek statue, believed to have been created around 130-100 BC, is a masterpiece of Hellenistic sculpture. The Venus de Milo is renowned for its graceful beauty and the mystery surrounding its missing arms.
|
||||||
|
|
||||||
|
3. Winged Victory of Samothrace: This Hellenistic sculpture, dating back to the 2nd century BC, depicts the Greek goddess Nike, the personification of victory. The sculpture is celebrated for its dynamic movement and intricate details.
|
||||||
|
|
||||||
|
4. Liberty Leading the People: This iconic painting by Eugène Delacroix, created in 1830, commemorates the July Revolution in France. The artwork depicts a woman personifying Liberty leading a group of revolutionaries over the bodies of the fallen.
|
||||||
|
|
||||||
|
5. The Wedding at Cana: A 1516 painting by Veronese, The Wedding at Cana is a large-scale work that depicts the biblical story of Jesus turning water into wine at a wedding feast. The painting is known for its vibrant colors and intricate details.
|
||||||
|
|
||||||
|
6. The Raft of the Medusa: This 1819 painting by Théodore Géricault is a powerful depiction of the aftermath of the shipwreck of the French frigate Méduse. The painting is famous for its dramatic composition and emotional intensity.
|
||||||
|
|
||||||
|
7. The Coronation of Napoleon: This 1805 painting by Jacques-Louis David portrays the coronation of Napoleon Bonaparte as Emperor of the French. The artwork is a masterpiece of neoclassical style and captures the pomp and ceremony of the event.
|
||||||
|
|
||||||
|
8. The Death of Socrates: A 1914 painting by Pablo Picasso, The Death of Socrates depicts the moment when the ancient Greek philosopher Socrates drinks the poison hemlock, as ordered by the Athenian government. The painting is a powerful expression of the existential themes associated with Socrates' death.
|
||||||
|
|
||||||
|
9. The Turkish Bath: A 1887 painting by Puvis de Chavannes, The Turkish Bath is a large-scale work that depicts a scene of relaxation and leisity in a traditional Turkish bath. The painting is known for its subdued color palette and serene atmosphere.
|
||||||
|
|
||||||
|
10. The Wedding at Cana: A 1901 painting by Giovanni Boldini, The Wedding at Cana is a lively depiction of a festive Italian wedding. The painting is characterized by Boldini's signature flamboyant style and his ability to capture the essence of the subjects with a sense of movement and energy.
|
||||||
|
|
||||||
|
These 10 famous items in the Louvre Museum showcase a diverse range of artistic styles and periods, from ancient Greek and Roman sculptures to Renaissance and 19th-century paintings. The collection at the Louvre is one of the world's largest and most important museums, housing over 800,000 works of art, and these famous items are just a small selection of the museum's vast collection.<|end|>
|
||||||
|
<|user|>
|
||||||
|
Please give me a list of 10 famous French dishes.<|end|>
|
||||||
|
<|assistant|>
|
||||||
|
"""
|
||||||
|
|
||||||
|
inputs = tokenizer(input_text, return_tensors="pt").to(device=torch_device)
|
||||||
|
outputs = model.generate(**inputs, max_new_tokens=100)
|
||||||
|
output_text = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True)
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
'1. Coq au Vin: Coq au Vin is a classic French dish that translates to "rooster in wine." The dish consists of chicken braised with wine, lardons, mushrooms, and garlic. It is a hearty and flavorful dish that is often served with potatoes or rice.\n\n 2. Boeuf Bourguignon: Boeuf Bourguignon is a traditional French beef stew that'
|
||||||
|
]
|
||||||
|
|
||||||
|
self.assertListEqual(output_text, EXPECTED_OUTPUT)
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ class Qwen2ModelTester:
|
|||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
max_window_layers=3,
|
max_window_layers=3,
|
||||||
use_sliding_window=True,
|
use_sliding_window=True,
|
||||||
sliding_window=2,
|
sliding_window=50,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
num_key_value_heads=2,
|
num_key_value_heads=2,
|
||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ class Qwen2MoeModelTester:
|
|||||||
num_hidden_layers=5,
|
num_hidden_layers=5,
|
||||||
max_window_layers=3,
|
max_window_layers=3,
|
||||||
use_sliding_window=True,
|
use_sliding_window=True,
|
||||||
sliding_window=2,
|
sliding_window=50,
|
||||||
num_attention_heads=4,
|
num_attention_heads=4,
|
||||||
num_key_value_heads=2,
|
num_key_value_heads=2,
|
||||||
intermediate_size=37,
|
intermediate_size=37,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import importlib
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from typing import Dict, List, Set
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
from check_copies import run_ruff
|
from check_copies import run_ruff
|
||||||
@@ -623,7 +623,7 @@ def get_new_part(class_name, base_class):
|
|||||||
return snake_case
|
return snake_case
|
||||||
|
|
||||||
|
|
||||||
def find_all_dependencies(function: str, dependency_mapping: dict[str, set]):
|
def find_all_dependencies(function: str, dependency_mapping: Dict[str, set]):
|
||||||
"""Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file:
|
"""Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file:
|
||||||
```
|
```
|
||||||
def foo1():
|
def foo1():
|
||||||
@@ -1001,8 +1001,8 @@ class ModularConverterTransformer(CSTTransformer):
|
|||||||
top_level_function: str,
|
top_level_function: str,
|
||||||
body: dict,
|
body: dict,
|
||||||
function_node: cst.FunctionDef,
|
function_node: cst.FunctionDef,
|
||||||
matching_callers: set | None = None,
|
matching_callers: Optional[set] = None,
|
||||||
parent: str | None = None,
|
parent: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers`
|
"""Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers`
|
||||||
is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return
|
is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return
|
||||||
|
|||||||
Reference in New Issue
Block a user