From f92d354823ca38cbb71441f763711a66fe90d7d4 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Fri, 4 Oct 2024 22:45:37 +0200 Subject: [PATCH] fix red check-copies (#33964) --- .../models/paligemma/modeling_paligemma.py | 2 +- .../models/vipllava/modeling_vipllava.py | 2 +- .../models/zamba/modeling_zamba.py | 54 ------------------- tests/models/phimoe/test_modeling_phimoe.py | 2 +- 4 files changed, 3 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 560edca663..cba66aa3a8 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -46,7 +46,7 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "PaliGemmaConfig" -# Adapted from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +# Adapted from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position # But Paligemma has no causal mask on prefix def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 9a2bb8fb0d..cacaaf6ac3 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -283,7 +283,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return model_embeds # Ignore copy - def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: list[int]): + def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]): image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # For VIP-llava, the image features are computed this way diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 8c6c49a3a1..2363ed0495 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -77,60 +77,6 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "ZambaConfig" -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba class ZambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): diff --git a/tests/models/phimoe/test_modeling_phimoe.py b/tests/models/phimoe/test_modeling_phimoe.py index 57e5e10fba..881967076e 100644 --- a/tests/models/phimoe/test_modeling_phimoe.py +++ b/tests/models/phimoe/test_modeling_phimoe.py @@ -154,7 +154,7 @@ class PhimoeModelTester: input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) token_type_ids = None if self.use_token_type_ids: