New BartModel (#2745)

* Results same as fairseq
* Wrote a ton of tests
* Struggled with API signatures
* Added some docs
This commit is contained in:
Sam Shleifer
2020-02-20 18:11:13 -05:00
committed by GitHub
parent 564fd75d65
commit 53ce3854a1
20 changed files with 1766 additions and 59 deletions

View File

@@ -1448,6 +1448,20 @@ class SequenceSummary(nn.Module):
return output
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """ Replace non-padding symbols with their position numbers. Position numbers begin at
    padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
    `utils.make_positions`.

    Positions are counted with a cumulative sum along dimension 1, so ``input_ids`` needs at
    least two dimensions (presumably ``(batch, seq_len)`` -- confirm against callers). Every
    padding entry is mapped to ``padding_idx`` itself and does not advance the count.

    :param torch.Tensor input_ids: token ids, compared element-wise against ``padding_idx``
    :param int padding_idx: id of the padding symbol; also the offset added to every position
    :return torch.Tensor: int64 tensor with the same shape as ``input_ids``
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    # 1 where the token is a real symbol, 0 where it is padding.
    mask = input_ids.ne(padding_idx).int()
    # Running count of non-padding tokens; multiplying by the mask zeroes the padding slots again.
    incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
    # Shift so real tokens start at padding_idx + 1 and padding lands exactly on padding_idx.
    return incremental_indices.long() + padding_idx
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.