Paligemma causal attention mask (#30967)
* PaliGemma working causal attention * Formatting * Style * Docstrings + remove commented code * Update docstring for PaliGemma Config * PaliGemma - add separator ind to model/labels * Refactor + docstring paligemma processor method * Style * return token type ids when tokenizing labels * use token type ids when building causal mask * add token type ids to tester * remove separator from config * fix style * don't ignore separator * add processor documentation * simplify tokenization * fix causal mask * style * fix label propagation, revert suffix naming * fix style * fix labels tokenization * [run-slow]paligemma * add eos if suffixes are present * [run-slow]paligemma * [run-slow]paligemma * add missing tokens to fast version * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style * [run-slow]paligemma --------- Co-authored-by: Peter Robicheaux <peter@roboflow.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -283,9 +283,14 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
self.vocab_size = model_embeds.num_embeddings
|
self.vocab_size = model_embeds.num_embeddings
|
||||||
return model_embeds
|
return model_embeds
|
||||||
|
|
||||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
def _merge_input_ids_with_image_features(
|
||||||
|
self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||||
|
):
|
||||||
_, _, embed_dim = image_features.shape
|
_, _, embed_dim = image_features.shape
|
||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
|
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
|
||||||
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
||||||
final_embedding = torch.zeros(
|
final_embedding = torch.zeros(
|
||||||
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||||
@@ -306,24 +311,43 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
|
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
|
||||||
)
|
)
|
||||||
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
||||||
|
if attention_mask is not None:
|
||||||
|
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
||||||
|
else:
|
||||||
|
position_ids = None
|
||||||
|
|
||||||
final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
if token_type_ids is not None and labels is not None:
|
||||||
final_attention_mask_4d = final_attention_mask_4d.float().expand(
|
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||||
-1, self.config.text_config.num_key_value_heads, -1, -1
|
target_length = cache_position[-1] + 1
|
||||||
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
# unmask the prefill
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
token_type_ids[:, None, None, :] == 0, 0
|
||||||
|
)
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
# position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
|
|
||||||
# position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
|
|
||||||
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
|
||||||
|
|
||||||
if labels is not None:
|
|
||||||
final_labels = torch.full(
|
final_labels = torch.full(
|
||||||
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||||
)
|
)
|
||||||
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
|
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
|
||||||
else:
|
else:
|
||||||
|
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||||
|
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
|
||||||
final_labels = None
|
final_labels = None
|
||||||
return final_embedding, final_attention_mask_4d, final_labels, position_ids
|
return final_embedding, causal_mask, final_labels, position_ids
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
@@ -334,6 +358,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
@@ -397,8 +422,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
selected_image_feature = image_outputs.last_hidden_state
|
selected_image_feature = image_outputs.last_hidden_state
|
||||||
image_features = self.multi_modal_projector(selected_image_feature)
|
image_features = self.multi_modal_projector(selected_image_feature)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -487,6 +514,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
cache_position=None,
|
cache_position=None,
|
||||||
pixel_values=None,
|
pixel_values=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
token_type_ids=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
past_length = 0
|
past_length = 0
|
||||||
@@ -545,6 +573,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
|||||||
"use_cache": kwargs.get("use_cache"),
|
"use_cache": kwargs.get("use_cache"),
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -163,6 +163,8 @@ class PaliGemmaVisionText2TextModelTester:
|
|||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"labels": input_ids,
|
||||||
|
"token_type_ids": torch.zeros_like(input_ids),
|
||||||
}
|
}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user