Add type hints for several pytorch models (batch-4) (#25749)
* Add type hints for MGP STR model * Add missing type hints for plbart model * Add type hints for Pix2struct model * Add missing type hints to Rag model and tweak the docstring * Add missing type hints to Sam model * Add missing type hints to Swin2sr model * Fix a type hint for Pix2StructTextModel Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Fix typo on Rag model docstring Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Fix linter --------- Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -380,7 +380,13 @@ class MgpstrModel(MgpstrPreTrainedModel):
|
|||||||
return self.embeddings.proj
|
return self.embeddings.proj
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
|
||||||
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -437,12 +443,12 @@ class MgpstrForSceneTextRecognition(MgpstrPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
|
@replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values,
|
pixel_values: torch.FloatTensor,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_a3_attentions=None,
|
output_a3_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
):
|
) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
output_a3_attentions (`bool`, *optional*):
|
output_a3_attentions (`bool`, *optional*):
|
||||||
Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
|
Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
|
||||||
|
|||||||
@@ -1387,21 +1387,21 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask=None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
inputs_embeds=None,
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
head_mask=None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_values=None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
use_cache=None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions=None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states=None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
labels=None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
|
|||||||
@@ -1177,7 +1177,7 @@ class PLBartModel(PLBartPreTrainedModel):
|
|||||||
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
@@ -1302,7 +1302,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
|
|||||||
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
encoder_outputs: Optional[List[torch.FloatTensor]] = None,
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
decoder_inputs_embeds=None,
|
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
labels: Optional[torch.Tensor] = None,
|
labels: Optional[torch.Tensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
|
|||||||
@@ -462,16 +462,12 @@ RAG_FORWARD_INPUTS_DOCSTRING = r"""
|
|||||||
`question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
|
`question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
|
||||||
context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
|
context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
|
||||||
Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
|
Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
|
||||||
retriever.
|
retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
|
||||||
|
the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
|
||||||
If the model has is not initialized with a `retriever` ``context_input_ids` has to be provided to the
|
context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
|
||||||
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. context_attention_mask
|
Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
|
||||||
(`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*,
|
retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
|
||||||
returned when *output_retrieved=True*): Attention mask post-processed from the retrieved documents and the
|
provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
|
||||||
question encoder `input_ids` by the retriever.
|
|
||||||
|
|
||||||
If the model has is not initialized with a `retriever` `context_attention_mask` has to be provided to the
|
|
||||||
forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
|
|
||||||
use_cache (`bool`, *optional*, defaults to `True`):
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||||
`past_key_values`).
|
`past_key_values`).
|
||||||
@@ -545,7 +541,7 @@ class RagModel(RagPreTrainedModel):
|
|||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
doc_scores: Optional[torch.FloatTensor] = None,
|
doc_scores: Optional[torch.FloatTensor] = None,
|
||||||
context_input_ids: Optional[torch.LongTensor] = None,
|
context_input_ids: Optional[torch.LongTensor] = None,
|
||||||
context_attention_mask=None,
|
context_attention_mask: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
|||||||
@@ -1296,7 +1296,7 @@ class SamModel(SamPreTrainedModel):
|
|||||||
target_embedding: Optional[torch.FloatTensor] = None,
|
target_embedding: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict=None,
|
return_dict: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> List[Dict[str, torch.Tensor]]:
|
) -> List[Dict[str, torch.Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
|
|||||||
@@ -903,7 +903,7 @@ class Swin2SRModel(Swin2SRPreTrainedModel):
|
|||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values,
|
pixel_values: torch.FloatTensor,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user