Fix aria tests (#39879)
* fix aria tests * awful bug * fix copies * fix tests * fix style * revert this
This commit is contained in:
committed by
GitHub
parent
6e4a9a5b43
commit
2589a52c5c
@@ -1014,18 +1014,9 @@ class AriaModel(AriaPreTrainedModel):
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[tuple, AriaModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@@ -1037,7 +1028,7 @@ class AriaModel(AriaPreTrainedModel):
|
||||
vision_feature_layer=self.config.vision_feature_layer,
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self._get_image_mask(
|
||||
special_image_mask = self.get_placeholder_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
@@ -1048,9 +1039,6 @@ class AriaModel(AriaPreTrainedModel):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1156,9 +1144,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
@@ -1223,12 +1208,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
>>> print(generated_texts[1])
|
||||
Assistant: The bridge is in San Francisco.
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
@@ -1238,9 +1217,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1414,18 +1414,9 @@ class AriaModel(LlavaModel):
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[tuple, AriaModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
@@ -1437,7 +1428,7 @@ class AriaModel(LlavaModel):
|
||||
vision_feature_layer=self.config.vision_feature_layer,
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self._get_image_mask(
|
||||
special_image_mask = self.get_placeholder_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
@@ -1448,9 +1439,6 @@ class AriaModel(LlavaModel):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1498,9 +1486,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
@@ -1565,12 +1550,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
>>> print(generated_texts[1])
|
||||
Assistant: The bridge is in San Francisco.
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
@@ -1580,9 +1559,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -137,8 +137,8 @@ class AriaVisionText2TextModelTester:
|
||||
|
||||
def get_config(self):
|
||||
return AriaConfig(
|
||||
text_config=self.text_config,
|
||||
vision_config=self.vision_config,
|
||||
text_config=self.text_config.to_dict(),
|
||||
vision_config=self.vision_config.to_dict(),
|
||||
ignore_index=self.ignore_index,
|
||||
image_token_index=self.image_token_index,
|
||||
projector_hidden_act=self.projector_hidden_act,
|
||||
|
||||
Reference in New Issue
Block a user