From e19b978151419fe0756ba852b145fccfc96dbeb4 Mon Sep 17 00:00:00 2001 From: Bijay Gurung Date: Sat, 23 May 2020 04:55:22 +0545 Subject: [PATCH] Add Type Hints to modeling_utils.py Closes #3911 (#3948) * Add Type Hints to modeling_utils.py Closes #3911 Add Type Hints to methods in `modeling_utils.py` Note: The coverage isn't 100%. Mostly skipped internal methods. * Reformat according to `black` and `isort` * Use typing.Iterable instead of Sequence * Parameterize Iterable by its generic type * Use typing.Optional when None is the default value * Adhere to style guideline * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py Co-authored-by: Julien Chaumond --- src/transformers/modeling_utils.py | 72 +++++++++++++++++------------- 1 file changed, 41 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7cd0df9552..84d415d05d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -17,7 +17,7 @@ import inspect import logging import os -from typing import Callable, List, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Tuple import torch from torch import Tensor, device, dtype, nn @@ -164,7 +164,7 @@ class ModuleUtilsMixin: return encoder_extended_attention_mask - def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device): + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor: """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored. Arguments: @@ -208,7 +208,7 @@ class ModuleUtilsMixin: extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask - def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False): + def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor: """ # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -302,7 +302,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): else: raise NotImplementedError - def set_input_embeddings(self, value): + def set_input_embeddings(self, value: nn.Module): """ Set model's input embeddings @@ -354,7 +354,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): output_embeddings.out_features = input_embeddings.num_embeddings - def resize_token_embeddings(self, new_num_tokens=None): + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None): """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. @@ -387,18 +387,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): self.set_input_embeddings(new_embeddings) return self.get_input_embeddings() - def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None): + def _get_resized_embeddings( + self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None + ) -> torch.nn.Embedding: """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly initialized vectors at the end Reducing the size will remove vectors from the end Args: + old_embeddings: ``torch.nn.Embedding`` + Old embeddings to be resized. new_num_tokens: (`optional`) int New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end Reducing the size will remove vectors from the end If not provided or None: return the provided token Embedding Module. - Return: ``torch.nn.Embeddings`` + Return: ``torch.nn.Embedding`` Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None """ if new_num_tokens is None: @@ -433,7 +437,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): # Tie weights if needed self.tie_weights() - def prune_heads(self, heads_to_prune): + def prune_heads(self, heads_to_prune: Dict): """ Prunes heads of the base model. Arguments: @@ -801,28 +805,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): @torch.no_grad() def generate( self, - input_ids=None, - max_length=None, - min_length=None, - do_sample=None, - early_stopping=None, - num_beams=None, - temperature=None, - top_k=None, - top_p=None, - repetition_penalty=None, - bad_words_ids=None, - bos_token_id=None, - pad_token_id=None, - eos_token_id=None, - length_penalty=None, - no_repeat_ngram_size=None, - num_return_sequences=None, - attention_mask=None, - decoder_start_token_id=None, - use_cache=None, + input_ids: Optional[torch.LongTensor] = None, + max_length: Optional[int] = None, + min_length: Optional[int] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[bool] = None, + num_beams: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + bad_words_ids: Optional[Iterable[int]] = None, + bos_token_id: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + length_penalty: Optional[float] = None, + no_repeat_ngram_size: Optional[int] = None, + num_return_sequences: Optional[int] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_start_token_id: Optional[int] = None, + use_cache: Optional[bool] = None, **model_specific_kwargs - ): + ) -> torch.LongTensor: r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. Adapted in part from `Facebook's XLM beam search code`_. @@ -1606,7 +1610,7 @@ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_n return banned_tokens -def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids): +def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]: banned_tokens = [] def _tokens_match(prev_tokens, tokens): @@ -1642,7 +1646,13 @@ def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids): return banned_tokens -def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): +def top_k_top_p_filtering( + logits: Tensor, + top_k: int = 0, + top_p: float = 1.0, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, +) -> Tensor: """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering Args: logits: logits distribution shape (batch size, vocabulary size)