Black 20 release
This commit is contained in:
@@ -51,8 +51,7 @@ try:
|
||||
except ImportError:
|
||||
# Older PyTorch compatibility
|
||||
class Identity(nn.Module):
|
||||
r"""A placeholder identity operator that is argument-insensitive.
|
||||
"""
|
||||
r"""A placeholder identity operator that is argument-insensitive."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
@@ -488,8 +487,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
)
|
||||
|
||||
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
||||
""" Tie or clone module weights depending of whether we are using TorchScript or not
|
||||
"""
|
||||
"""Tie or clone module weights depending of whether we are using TorchScript or not"""
|
||||
if self.config.torchscript:
|
||||
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
|
||||
else:
|
||||
@@ -498,7 +496,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
if getattr(output_embeddings, "bias", None) is not None:
|
||||
output_embeddings.bias.data = torch.nn.functional.pad(
|
||||
output_embeddings.bias.data,
|
||||
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
|
||||
(
|
||||
0,
|
||||
output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
@@ -906,7 +907,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
||||
def load(module: nn.Module, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
module._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
True,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
@@ -1242,24 +1249,24 @@ class SQuADHead(nn.Module):
|
||||
return_dict: bool = False,
|
||||
) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
|
||||
Final hidden states of the model on the sequence tokens.
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Positions of the first token for the labeled span.
|
||||
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Positions of the last token for the labeled span.
|
||||
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
|
||||
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Whether the question has a possible answer in the paragraph or not.
|
||||
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
|
||||
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
|
||||
1.0 means token should be masked.
|
||||
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
|
||||
Args:
|
||||
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len, hidden_size)`):
|
||||
Final hidden states of the model on the sequence tokens.
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Positions of the first token for the labeled span.
|
||||
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Positions of the last token for the labeled span.
|
||||
cls_index (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Position of the CLS token for each sentence in the batch. If :obj:`None`, takes the last token.
|
||||
is_impossible (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Whether the question has a possible answer in the paragraph or not.
|
||||
p_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, seq_len)`, `optional`):
|
||||
Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS).
|
||||
1.0 means token should be masked.
|
||||
return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to return a :class:`~transformers.file_utils.ModelOuput` instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
Returns:
|
||||
"""
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
|
||||
@@ -1375,7 +1382,7 @@ class SequenceSummary(nn.Module):
|
||||
self.summary = nn.Linear(config.hidden_size, num_classes)
|
||||
|
||||
activation_string = getattr(config, "summary_activation", None)
|
||||
self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
|
||||
self.activation: Callable = get_activation(activation_string) if activation_string else Identity()
|
||||
|
||||
self.first_dropout = Identity()
|
||||
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
|
||||
@@ -1409,7 +1416,11 @@ class SequenceSummary(nn.Module):
|
||||
output = hidden_states.mean(dim=1)
|
||||
elif self.summary_type == "cls_index":
|
||||
if cls_index is None:
|
||||
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
|
||||
cls_index = torch.full_like(
|
||||
hidden_states[..., :1, :],
|
||||
hidden_states.shape[-2] - 1,
|
||||
dtype=torch.long,
|
||||
)
|
||||
else:
|
||||
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
|
||||
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
|
||||
|
||||
Reference in New Issue
Block a user