Paligemma causal attention mask (#30967)
* PaliGemma working causal attention * Formatting * Style * Docstrings + remove commented code * Update docstring for PaliGemma Config * PaliGemma - add separator ind to model/labels * Refactor + docstring paligemma processor method * Style * return token type ids when tokenizing labels * use token type ids when building causal mask * add token type ids to tester * remove separator from config * fix style * don't ignore separator * add processor documentation * simplify tokenization * fix causal mask * style * fix label propagation, revert suffix naming * fix style * fix labels tokenization * [run-slow]paligemma * add eos if suffixes are present * [run-slow]paligemma * [run-slow]paligemma * add misssing tokens to fast version * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix style * [run-slow]paligemma --------- Co-authored-by: Peter Robicheaux <peter@roboflow.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -283,9 +283,14 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
return model_embeds
|
||||
|
||||
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
|
||||
def _merge_input_ids_with_image_features(
|
||||
self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||
):
|
||||
_, _, embed_dim = image_features.shape
|
||||
batch_size, sequence_length = input_ids.shape
|
||||
dtype, device = inputs_embeds.dtype, inputs_embeds.device
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
|
||||
scaled_image_features = image_features / (self.config.hidden_size**0.5)
|
||||
final_embedding = torch.zeros(
|
||||
batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
||||
@@ -306,24 +311,43 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
|
||||
)
|
||||
final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
|
||||
if attention_mask is not None:
|
||||
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
||||
else:
|
||||
position_ids = None
|
||||
|
||||
final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||
final_attention_mask_4d = final_attention_mask_4d.float().expand(
|
||||
-1, self.config.text_config.num_key_value_heads, -1, -1
|
||||
)
|
||||
if token_type_ids is not None and labels is not None:
|
||||
# we are training thus we need to create a full mask on the image + prefix but causal on suffix
|
||||
target_length = cache_position[-1] + 1
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
# unmask the prefill
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
token_type_ids[:, None, None, :] == 0, 0
|
||||
)
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
# position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
|
||||
# position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
|
||||
position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
|
||||
|
||||
if labels is not None:
|
||||
final_labels = torch.full(
|
||||
(batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
|
||||
)
|
||||
final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
|
||||
else:
|
||||
causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
|
||||
causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
|
||||
final_labels = None
|
||||
return final_embedding, final_attention_mask_4d, final_labels, position_ids
|
||||
return final_embedding, causal_mask, final_labels, position_ids
|
||||
|
||||
@add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
@@ -334,6 +358,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
@@ -397,8 +422,10 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels
|
||||
image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -487,6 +514,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
cache_position=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
**kwargs,
|
||||
):
|
||||
past_length = 0
|
||||
@@ -545,6 +573,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel):
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"pixel_values": pixel_values,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -163,6 +163,8 @@ class PaliGemmaVisionText2TextModelTester:
|
||||
"pixel_values": pixel_values,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": input_ids,
|
||||
"token_type_ids": torch.zeros_like(input_ids),
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user