Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -51,8 +51,7 @@ try:
except ImportError:
# Older PyTorch compatibility
class Identity(nn.Module):
r"""A placeholder identity operator that is argument-insensitive.
"""
r"""A placeholder identity operator that is argument-insensitive."""
def __init__(self, *args, **kwargs):
super().__init__()
@@ -488,8 +487,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
)
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
""" Tie or clone module weights depending of whether we are using TorchScript or not
"""
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
if self.config.torchscript:
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
else:
@@ -498,7 +496,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if getattr(output_embeddings, "bias", None) is not None:
output_embeddings.bias.data = torch.nn.functional.pad(
output_embeddings.bias.data,
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
(
0,
output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
),
"constant",
0,
)
@@ -906,7 +907,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
@@ -1242,24 +1249,24 @@ class SQuADHead(nn.Module):
return_dict: bool = False,
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
"""
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
Final hidden states of the model on the sequence tokens.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the first token for the labeled span.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the last token for the labeled span.
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Whether the question has a possible answer in the paragraph or not.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
Final hidden states of the model on the sequence tokens.
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the first token for the labeled span.
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Positions of the last token for the labeled span.
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Whether the question has a possible answer in the paragraph or not.
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
1.0 means token should be masked.
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
Returns:
Returns:
"""
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
@@ -1375,7 +1382,7 @@ class SequenceSummary(nn.Module):
self.summary = nn.Linear(config.hidden_size, num_classes)
activation_string = getattr(config, "summary_activation", None)
self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
self.first_dropout = Identity()
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
@@ -1409,7 +1416,11 @@ class SequenceSummary(nn.Module):
output = hidden_states.mean(dim=1)
elif self.summary_type == "cls_index":
if cls_index is None:
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
cls_index = torch.full_like(
hidden_states[..., :1, :],
hidden_states.shape[-2] - 1,
dtype=torch.long,
)
else:
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))