[bart] add config.extra_pos_embeddings to facilitate reuse (#5190)
This commit is contained in:
@@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||
`utils.make_positions`.
|
||||
|
||||
:param torch.Tensor x:
|
||||
:return torch.Tensor:
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
def prune_linear_layer(layer, index, dim=0):
|
||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
|
||||
Reference in New Issue
Block a user