Add Optional to types (#37163)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -902,7 +902,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(ARIA_TEXT_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
@@ -1189,7 +1189,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
|
||||
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
@@ -1303,7 +1303,7 @@ class AriaCausalLMOutputWithPast(ModelOutput):
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
logits: Optional[torch.FloatTensor] = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
@@ -1424,7 +1424,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_mask: torch.FloatTensor = None,
|
||||
pixel_mask: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: int = -1,
|
||||
):
|
||||
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
||||
@@ -1446,9 +1446,9 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_mask: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_mask: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
|
||||
@@ -1402,7 +1402,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
def get_image_features(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
pixel_mask: torch.FloatTensor = None,
|
||||
pixel_mask: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: int = -1,
|
||||
):
|
||||
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
|
||||
@@ -1424,9 +1424,9 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=AriaConfig)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_mask: torch.LongTensor = None,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
pixel_mask: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
|
||||
Reference in New Issue
Block a user