>3-5x faster torch.compile forward compilation for autoregressive decoder models (#32227)

* draft

* apply changes to all relevant archs

* rerun ci - check_docstrings.py failing?

* fix docstring

* move 2D->4D mask creation to modeling file

* repo consistency

* fix the batch size = 1 case - calling contiguous is not enough

* nit

* style

* propagate to gemma/gemma-2

* prepare inputs for gemma generation

* implement test and tiny fix in gemma2

* Update src/transformers/models/bloom/modeling_bloom.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix copies

* ci pass

* fix gemma's test_compile_static_cache tests

* flacky

* retrigger ci

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
fxmarty
2024-07-31 20:03:07 +02:00
committed by GitHub
parent b46bd8b9d2
commit 92abe60334
22 changed files with 1601 additions and 423 deletions

View File

@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy import copy
import inspect import inspect
import warnings import warnings

View File

@@ -45,6 +45,60 @@ _CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig" _CONFIG_FOR_DOC = "BloomConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
""" """
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
@@ -774,27 +828,18 @@ class BloomModel(BloomPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@@ -50,6 +50,60 @@ if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ChameleonConfig" _CONFIG_FOR_DOC = "ChameleonConfig"
@@ -1417,27 +1471,18 @@ class ChameleonModel(ChameleonPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@@ -59,6 +59,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "CohereConfig" _CONFIG_FOR_DOC = "CohereConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class CohereLayerNorm(nn.Module): class CohereLayerNorm(nn.Module):
def __init__(self, hidden_size=None, eps=1e-5, bias=False): def __init__(self, hidden_size=None, eps=1e-5, bias=False):
"""The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim""" """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
@@ -895,27 +949,18 @@ class CohereModel(CoherePreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1082,11 +1127,36 @@ class CohereForCausalLM(CoherePreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -45,6 +45,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DbrxConfig" _CONFIG_FOR_DOC = "DbrxConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx
class DbrxRotaryEmbedding(nn.Module): class DbrxRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -1147,27 +1201,18 @@ class DbrxModel(DbrxPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1353,11 +1398,36 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -52,6 +52,60 @@ from .configuration_gemma import GemmaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class GemmaRMSNorm(nn.Module): class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
@@ -914,26 +968,16 @@ class GemmaModel(GemmaPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
@@ -1094,12 +1138,36 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
position_ids.masked_fill_(attention_mask == 0, 1) position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -54,6 +54,60 @@ if is_flash_attn_2_available():
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class Gemma2RMSNorm(nn.Module): class Gemma2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
super().__init__() super().__init__()
@@ -845,26 +899,16 @@ class Gemma2Model(Gemma2PreTrainedModel):
else: else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
return causal_mask return causal_mask
@@ -1024,12 +1068,41 @@ class Gemma2ForCausalLM(Gemma2PreTrainedModel):
position_ids.masked_fill_(attention_mask == 0, 1) position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases # The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format)}
if isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -25,9 +25,7 @@ from torch.nn import functional as F
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
MoeCausalLMOutputWithPast, MoeCausalLMOutputWithPast,
MoeModelOutputWithPast, MoeModelOutputWithPast,
@@ -54,6 +52,60 @@ _CHECKPOINT_FOR_DOC = "jetmoe"
_CONFIG_FOR_DOC = "JetMoeConfig" _CONFIG_FOR_DOC = "JetMoeConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
def load_balancing_loss_func( def load_balancing_loss_func(
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
@@ -1121,27 +1173,18 @@ class JetMoeModel(JetMoePreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@@ -55,6 +55,59 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig" _CONFIG_FOR_DOC = "LlamaConfig"
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
""" """
@@ -1029,27 +1082,18 @@ class LlamaModel(LlamaPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1216,11 +1260,36 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -1072,7 +1072,6 @@ class MistralForCausalLM(MistralPreTrainedModel):
attentions=outputs.attentions, attentions=outputs.attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, self,
input_ids, input_ids,
@@ -1100,6 +1099,9 @@ class MistralForCausalLM(MistralPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}

View File

@@ -30,10 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
MoeCausalLMOutputWithPast, MoeCausalLMOutputWithPast,
MoeModelOutputWithPast, MoeModelOutputWithPast,
@@ -70,6 +67,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MixtralConfig" _CONFIG_FOR_DOC = "MixtralConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def load_balancing_loss_func( def load_balancing_loss_func(
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
) -> float: ) -> float:
@@ -1104,27 +1155,18 @@ class MixtralModel(MixtralPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@@ -57,6 +57,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OlmoConfig" _CONFIG_FOR_DOC = "OlmoConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class OlmoLayerNorm(nn.Module): class OlmoLayerNorm(nn.Module):
"""LayerNorm but with no learnable weight or bias.""" """LayerNorm but with no learnable weight or bias."""
@@ -941,27 +995,18 @@ class OlmoModel(OlmoPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1126,11 +1171,36 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -29,9 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -48,6 +46,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig" _CONFIG_FOR_DOC = "PersimmonConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Persimmon
class PersimmonRotaryEmbedding(nn.Module): class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -755,27 +807,18 @@ class PersimmonModel(PersimmonPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -945,11 +988,36 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -26,9 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -59,6 +57,60 @@ _CHECKPOINT_FOR_DOC = "microsoft/phi-1"
_CONFIG_FOR_DOC = "PhiConfig" _CONFIG_FOR_DOC = "PhiConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi
class PhiRotaryEmbedding(nn.Module): class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -1039,27 +1091,18 @@ class PhiModel(PhiPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1230,11 +1273,36 @@ class PhiForCausalLM(PhiPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -26,9 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -57,6 +55,60 @@ _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
_CONFIG_FOR_DOC = "Phi3Config" _CONFIG_FOR_DOC = "Phi3Config"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
class Phi3RMSNorm(nn.Module): class Phi3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
@@ -1082,27 +1134,18 @@ class Phi3Model(Phi3PreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1273,11 +1316,36 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -29,9 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -61,6 +59,60 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
_CONFIG_FOR_DOC = "Qwen2Config" _CONFIG_FOR_DOC = "Qwen2Config"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
class Qwen2RMSNorm(nn.Module): class Qwen2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
@@ -941,27 +993,18 @@ class Qwen2Model(Qwen2PreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1125,11 +1168,36 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -30,9 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
MoeCausalLMOutputWithPast, MoeCausalLMOutputWithPast,
MoeModelOutputWithPast, MoeModelOutputWithPast,
@@ -60,6 +58,60 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B"
_CONFIG_FOR_DOC = "Qwen2MoeConfig" _CONFIG_FOR_DOC = "Qwen2MoeConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
def load_balancing_loss_func( def load_balancing_loss_func(
gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
@@ -1114,27 +1166,18 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1321,11 +1364,36 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -29,9 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -59,6 +57,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "StableLmConfig" _CONFIG_FOR_DOC = "StableLmConfig"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->StableLm
class StableLmRotaryEmbedding(nn.Module): class StableLmRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -1032,27 +1084,18 @@ class StableLmModel(StableLmPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1223,11 +1266,36 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -29,9 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import ( from ...modeling_attn_mask_utils import AttentionMaskConverter
AttentionMaskConverter,
)
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -59,6 +57,60 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Starcoder2Config" _CONFIG_FOR_DOC = "Starcoder2Config"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module): class Starcoder2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -915,27 +967,18 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None
@@ -1101,11 +1144,36 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
if past_key_values: if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :] position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0: if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds} model_inputs = {"inputs_embeds": inputs_embeds}
else: else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases model_inputs = {"input_ids": input_ids}
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if inputs_embeds is not None:
batch_size, sequence_length = inputs_embeds.shape
device = inputs_embeds.device
else:
batch_size, sequence_length = input_ids.shape
device = input_ids.device
dtype = self.lm_head.weight.dtype
min_dtype = torch.finfo(dtype).min
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=batch_size,
)
model_inputs.update( model_inputs.update(
{ {

View File

@@ -59,6 +59,60 @@ _CONFIG_FOR_DOC = "WhisperConfig"
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny" _CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
"""Returns sinusoids for positional embedding""" """Returns sinusoids for positional embedding"""
if channels % 2 != 0: if channels % 2 != 0:
@@ -1412,27 +1466,18 @@ class WhisperDecoder(WhisperPreTrainedModel):
else past_seen_tokens + sequence_length + 1 else past_seen_tokens + sequence_length + 1
) )
if attention_mask is not None and attention_mask.dim() == 4: # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
if attention_mask.max() != 0: attention_mask,
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") sequence_length=sequence_length,
causal_mask = attention_mask target_length=target_length,
else: dtype=dtype,
causal_mask = torch.full( device=device,
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device min_dtype=min_dtype,
) cache_position=cache_position,
if sequence_length != 1: batch_size=input_tensor.shape[0],
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
) )
if ( if (
self.config._attn_implementation == "sdpa" self.config._attn_implementation == "sdpa"
and attention_mask is not None and attention_mask is not None

View File

@@ -816,7 +816,7 @@ class GemmaIntegrationTest(unittest.TestCase):
# Dynamic Cache # Dynamic Cache
generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False)
dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) # Both GPU architectures have the same output
# Static Cache # Static Cache
generated_ids = model.generate( generated_ids = model.generate(

View File

@@ -22,6 +22,7 @@ import os.path
import random import random
import re import re
import tempfile import tempfile
import time
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@@ -37,6 +38,7 @@ from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoTokenizer, AutoTokenizer,
GenerationConfig,
PretrainedConfig, PretrainedConfig,
PreTrainedModel, PreTrainedModel,
is_torch_available, is_torch_available,
@@ -4605,7 +4607,6 @@ class ModelTesterMixin:
tokenizer = AutoTokenizer.from_pretrained(ckpt) tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device) model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
model.generation_config.max_new_tokens = 4
model.generation_config.max_new_tokens = 4 model.generation_config.max_new_tokens = 4
model.generation_config.cache_implementation = "static" model.generation_config.cache_implementation = "static"
@@ -4617,6 +4618,66 @@ class ModelTesterMixin:
for i in range(n_iter): for i in range(n_iter):
_ = model.generate(**input_ids, do_sample=False) _ = model.generate(**input_ids, do_sample=False)
@slow
@require_torch_gpu # Testing cuda graphs.
@require_read_token
def test_compile_cuda_graph_time(self):
if version.parse(torch.__version__) < version.parse("2.3"):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
# TODO felix: All models supporting `StaticCache` or `torch.compile` should be tested.
# At the moment, only llama, gemma and gemma2 are tested here!
if not hasattr(self, "_torch_compile_test_ckpt"):
self.skipTest(f"{self.__class__.__name__} doesn't have the attribute `_torch_compile_test_ckpt`.")
ckpt = self._torch_compile_test_ckpt
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to(torch_device)
cache_implementation = "static"
if model.config.model_type == "gemma2":
cache_implementation = "hybrid"
new_tokens = 50
gen_config = GenerationConfig(
max_new_tokens=new_tokens,
min_new_tokens=new_tokens,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
num_beams=1,
do_sample=False,
eos_token_id=None, # This is required for min_new_tokens to actually have an effect.
)
model.generation_config.eos_token_id = None # greedy_search falls back on this eos_token_id that we need to set to None as well for min_new_tokens to have an effect.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
inp = tokenizer("Why cats are cute?", return_tensors="pt").to(torch_device)
# First run: the first run warms up each graph, which does things like CuBlas or Triton benchmarking
start = time.perf_counter()
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
end = time.perf_counter()
graph_warmup_time = end - start
# Second run: CUDA Graph recording, and replays it
start = time.perf_counter()
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
end = time.perf_counter()
record_time = end - start
# Finally: we hit the optimized, CUDA Graph replay path
start = time.perf_counter()
_ = model.generate(**inp, generation_config=gen_config, cache_implementation=cache_implementation)
end = time.perf_counter()
opt_time = end - start
# For the recording step, we expect only two cuda graphs and this step should be much faster than the first.
self.assertTrue(record_time < 0.15 * graph_warmup_time)
self.assertTrue(opt_time < record_time)
global_rng = random.Random() global_rng = random.Random()