Add type hints for several pytorch models (batch-4) (#25749)

* Add type hints for MGP STR model

* Add missing type hints for plbart model

* Add type hints for Pix2struct model

* Add missing type hints to Rag model and tweak the docstring

* Add missing type hints to Sam model

* Add missing type hints to Swin2sr model

* Fix a type hint for Pix2StructTextModel

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Fix typo on Rag model docstring

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Fix linter

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
David Reguera
2023-08-28 15:31:33 +02:00
committed by GitHub
parent ed915cff97
commit 886b6be081
6 changed files with 38 additions and 36 deletions

View File

@@ -380,7 +380,13 @@ class MgpstrModel(MgpstrPreTrainedModel):
return self.embeddings.proj return self.embeddings.proj
@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -437,12 +443,12 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
@replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig) @replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
def forward( def forward(
self, self,
pixel_values, pixel_values: torch.FloatTensor,
output_attentions=None, output_attentions: Optional[bool] = None,
output_a3_attentions=None, output_a3_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
r""" r"""
output_a3_attentions (`bool`, *optional*): output_a3_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors

View File

@@ -1387,21 +1387,21 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs,
): ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
r""" r"""
Returns: Returns:

View File

@@ -1177,7 +1177,7 @@ class PLBartModel(PLBartPreTrainedModel):
encoder_outputs: Optional[List[torch.FloatTensor]] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
@@ -1302,7 +1302,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
encoder_outputs: Optional[List[torch.FloatTensor]] = None, encoder_outputs: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None, past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds=None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,

View File

@@ -462,16 +462,12 @@ RAG_FORWARD_INPUTS_DOCSTRING = r"""
`question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information. `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*): context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
retriever. retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
If the model has is not initialized with a `retriever` ``context_input_ids` has to be provided to the context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. context_attention_mask Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
(`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
returned when *output_retrieved=True*): Attention mask post-processed from the retrieved documents and the provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
question encoder `input_ids` by the retriever.
If the model has is not initialized with a `retriever` `context_attention_mask` has to be provided to the
forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
use_cache (`bool`, *optional*, defaults to `True`): use_cache (`bool`, *optional*, defaults to `True`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`). `past_key_values`).
@@ -545,7 +541,7 @@ class RagModel(RagPreTrainedModel):
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
doc_scores: Optional[torch.FloatTensor] = None, doc_scores: Optional[torch.FloatTensor] = None,
context_input_ids: Optional[torch.LongTensor] = None, context_input_ids: Optional[torch.LongTensor] = None,
context_attention_mask=None, context_attention_mask: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,

View File

@@ -1296,7 +1296,7 @@ class SamModel(SamPreTrainedModel):
target_embedding: Optional[torch.FloatTensor] = None, target_embedding: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs,
) -> List[Dict[str, torch.Tensor]]: ) -> List[Dict[str, torch.Tensor]]:
r""" r"""

View File

@@ -903,7 +903,7 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values, pixel_values: torch.FloatTensor,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,