Prophetnet optimization (#9453)
* Vectorized `ngram_attention_bias` calculation
* Updated formatting with black
* Further optimization
* One (last) optimization
This commit is contained in:
@@ -171,13 +171,15 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
|
|||||||
"""
|
"""
|
||||||
This function computes the bias for the predict stream
|
This function computes the bias for the predict stream
|
||||||
"""
|
"""
|
||||||
bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf")
|
left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
|
||||||
|
right_block = left_block.detach().clone()
|
||||||
# create bias
|
# create bias
|
||||||
for stream_idx in range(ngram):
|
for stream_idx in range(ngram):
|
||||||
for i in range(sequence_length):
|
right_block[stream_idx].fill_diagonal_(0, wrap=False)
|
||||||
bias[stream_idx, i, sequence_length + i] = 0
|
left_block[stream_idx].triu_(-stream_idx + 1)
|
||||||
bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0
|
|
||||||
return bias
|
left_block[:, :, 0] = 0
|
||||||
|
return torch.cat([left_block, right_block], dim=2)
|
||||||
|
|
||||||
|
|
||||||
def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
|
def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user