Fix typos

This commit is contained in:
sshleifer
2020-04-22 17:50:18 -04:00
committed by Julien Chaumond
parent 12bb7fe770
commit 41750a6cff

View File

@@ -302,7 +302,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
def _tie_or_clone_weights(self, output_embeddings, input_embeddings): def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
""" Tie or clone module weights depending on weither we are using TorchScript or not """ Tie or clone module weights depending on whether we are using TorchScript or not
""" """
if self.config.torchscript: if self.config.torchscript:
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
@@ -1524,7 +1524,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
return decoded return decoded
# force one of token_ids to be generated by setting prob of all other tokens to 0. # force one of token_ids to be generated by setting prob of all other tokens to 0.
def _force_token_ids_generation(self, scores, token_ids): def _force_token_ids_generation(self, scores, token_ids) -> None:
if isinstance(token_ids, int): if isinstance(token_ids, int):
token_ids = [token_ids] token_ids = [token_ids]
all_but_token_ids_mask = torch.tensor( all_but_token_ids_mask = torch.tensor(
@@ -2025,8 +2025,8 @@ def create_position_ids_from_input_ids(input_ids, padding_idx):
""" """
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
mask = input_ids.ne(padding_idx).int() mask = input_ids.ne(padding_idx).int()
incremental_indicies = torch.cumsum(mask, dim=1).type_as(mask) * mask incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indicies.long() + padding_idx return incremental_indices.long() + padding_idx
def prune_linear_layer(layer, index, dim=0): def prune_linear_layer(layer, index, dim=0):