From 390cf16bc84d096979647f494c84dd67576f6166 Mon Sep 17 00:00:00 2001 From: guillaume-be Date: Thu, 7 Jan 2021 11:41:58 +0100 Subject: [PATCH] Prophetnet optimization (#9453) * Vectorized `ngram_attention_bias` calculation * updated formatting with black * Further optimization * one (last) optimization --- .../models/prophetnet/modeling_prophetnet.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 11197182f8..ae53e54ee8 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -171,13 +171,15 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype): """ This function computes the bias for the predict stream """ - bias = torch.ones((ngram, sequence_length, 2 * sequence_length), device=device, dtype=dtype) * float("-inf") + left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf") + right_block = left_block.detach().clone() # create bias for stream_idx in range(ngram): - for i in range(sequence_length): - bias[stream_idx, i, sequence_length + i] = 0 - bias[stream_idx, i, : max(i - stream_idx, 0) + 1] = 0 - return bias + right_block[stream_idx].fill_diagonal_(0, wrap=False) + left_block[stream_idx].triu_(-stream_idx + 1) + + left_block[:, :, 0] = 0 + return torch.cat([left_block, right_block], dim=2) def compute_relative_buckets(num_buckets, max_distance, relative_positions, is_bidirectional=False):