Fix aria tests (#39879)
* fix aria tests
* awful bug
* fix copies
* fix tests
* fix style
* revert this
This commit is contained in:
committed by GitHub
parent 6e4a9a5b43
commit 2589a52c5c
@@ -1014,18 +1014,9 @@ class AriaModel(AriaPreTrainedModel):
|
|||||||
past_key_values: Optional[Cache] = None,
|
past_key_values: Optional[Cache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Union[tuple, AriaModelOutputWithPast]:
|
) -> Union[tuple, AriaModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
@@ -1037,7 +1028,7 @@ class AriaModel(AriaPreTrainedModel):
|
|||||||
vision_feature_layer=self.config.vision_feature_layer,
|
vision_feature_layer=self.config.vision_feature_layer,
|
||||||
)
|
)
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
special_image_mask = self._get_image_mask(
|
special_image_mask = self.get_placeholder_mask(
|
||||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||||
)
|
)
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||||
@@ -1048,9 +1039,6 @@ class AriaModel(AriaPreTrainedModel):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=True,
|
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -1156,9 +1144,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
|||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs: Unpack[TransformersKwargs],
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
@@ -1223,12 +1208,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
|||||||
>>> print(generated_texts[1])
|
>>> print(generated_texts[1])
|
||||||
Assistant: The bridge is in San Francisco.
|
Assistant: The bridge is in San Francisco.
|
||||||
```"""
|
```"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
@@ -1238,9 +1217,6 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1414,18 +1414,9 @@ class AriaModel(LlavaModel):
|
|||||||
past_key_values: Optional[Cache] = None,
|
past_key_values: Optional[Cache] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs: Unpack[FlashAttentionKwargs],
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
) -> Union[tuple, AriaModelOutputWithPast]:
|
) -> Union[tuple, AriaModelOutputWithPast]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||||
|
|
||||||
@@ -1437,7 +1428,7 @@ class AriaModel(LlavaModel):
|
|||||||
vision_feature_layer=self.config.vision_feature_layer,
|
vision_feature_layer=self.config.vision_feature_layer,
|
||||||
)
|
)
|
||||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
special_image_mask = self._get_image_mask(
|
special_image_mask = self.get_placeholder_mask(
|
||||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||||
)
|
)
|
||||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||||
@@ -1448,9 +1439,6 @@ class AriaModel(LlavaModel):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=True,
|
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -1498,9 +1486,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
|||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs: Unpack[TransformersKwargs],
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
@@ -1565,12 +1550,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
|||||||
>>> print(generated_texts[1])
|
>>> print(generated_texts[1])
|
||||||
Assistant: The bridge is in San Francisco.
|
Assistant: The bridge is in San Francisco.
|
||||||
```"""
|
```"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
@@ -1580,9 +1559,6 @@ class AriaForConditionalGeneration(LlavaForConditionalGeneration):
|
|||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -137,8 +137,8 @@ class AriaVisionText2TextModelTester:
|
|||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return AriaConfig(
|
return AriaConfig(
|
||||||
text_config=self.text_config,
|
text_config=self.text_config.to_dict(),
|
||||||
vision_config=self.vision_config,
|
vision_config=self.vision_config.to_dict(),
|
||||||
ignore_index=self.ignore_index,
|
ignore_index=self.ignore_index,
|
||||||
image_token_index=self.image_token_index,
|
image_token_index=self.image_token_index,
|
||||||
projector_hidden_act=self.projector_hidden_act,
|
projector_hidden_act=self.projector_hidden_act,
|
||||||
|
|||||||
Reference in New Issue
Block a user