Fix aria tests (#39879)

* fix aria tests

* awful bug

* fix copies

* fix tests

* fix style

* revert this
This commit is contained in:
Raushan Turganbay
2025-08-05 13:48:47 +02:00
committed by GitHub
parent 6e4a9a5b43
commit 2589a52c5c
3 changed files with 4 additions and 52 deletions

View File

@@ -1014,18 +1014,9 @@ class AriaModel(AriaPreTrainedModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -1037,7 +1028,7 @@ class AriaModel(AriaPreTrainedModel):
vision_feature_layer=self.config.vision_feature_layer,
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
special_image_mask = self._get_image_mask(
special_image_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -1048,9 +1039,6 @@ class AriaModel(AriaPreTrainedModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
@@ -1156,9 +1144,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
@@ -1223,12 +1208,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
>>> print(generated_texts[1])
Assistant: The bridge is in San Francisco.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
@@ -1238,9 +1217,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

View File

@@ -1414,18 +1414,9 @@ class AriaModel(LlavaModel):
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, AriaModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
@@ -1437,7 +1428,7 @@ class AriaModel(LlavaModel):
vision_feature_layer=self.config.vision_feature_layer,
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
special_image_mask = self._get_image_mask(
special_image_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@@ -1448,9 +1439,6 @@ class AriaModel(LlavaModel):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**kwargs,
)
@@ -1498,9 +1486,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
@@ -1565,12 +1550,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
>>> print(generated_texts[1])
Assistant: The bridge is in San Francisco.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
@@ -1580,9 +1559,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

View File

@@ -137,8 +137,8 @@ class AriaVisionText2TextModelTester:
def get_config(self):
return AriaConfig(
text_config=self.text_config,
vision_config=self.vision_config,
text_config=self.text_config.to_dict(),
vision_config=self.vision_config.to_dict(),
ignore_index=self.ignore_index,
image_token_index=self.image_token_index,
projector_hidden_act=self.projector_hidden_act,