From 6bf885375ad2027a5bdf67cf84055e8eb0f8064f Mon Sep 17 00:00:00 2001 From: Kian Sierra McGettigan <47116198+kiansierra@users.noreply.github.com> Date: Thu, 2 Mar 2023 18:07:45 +0100 Subject: [PATCH] Prophetnet batch dimension inversion fix (#21870) * decoder forward pass is working * no model has forward pass returning attentions * decoder ngram changed to not mix batch size * current basic forward pass returns identical result * passed test_model attentions * passed test_encoder_decoder_model_generate * passed test_headmasking * removed old block * removed comments bug/fixme * removed bug comments * applied styling * applied fix-copies * applied ngram forward comments * corrected dimension notation * applied styling and comment fixes * changed asserts for raise ValueError * changed question gen test * updated hidden_states integration test * applied styling --- .../models/prophetnet/modeling_prophetnet.py | 286 ++++++++---------- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 286 ++++++++---------- .../prophetnet/test_modeling_prophetnet.py | 4 +- 3 files changed, 262 insertions(+), 314 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 61b6d943af..56a89d88f8 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -701,44 +701,27 @@ class ProphetNetAttention(nn.Module): past_key_value = (key_states, value_states) # project states into the correct shape - proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - assert attn_weights.size() == ( - batch_size * self.num_attn_heads, - tgt_len, - src_len, - ), ( - f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size" - f" {attn_weights.shape}" - ) + src_len = key_states.size(2) + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) + if attn_weights.size() != expected_shape: + raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}") # This is part of a workaround to get around fork/join parallelism not supporting Optional types. if attention_mask is not None and attention_mask.dim() == 0: attention_mask = None - assert attention_mask is None or attention_mask.size() == ( - self.num_attn_heads * batch_size, - 1, - src_len, - ), ( - "`attention_mask` should be `None` or of shape attention_mask.size() ==" - f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}" - ) + expected_shape = (batch_size, self.num_attn_heads, 1, src_len) + if attention_mask is not None and attention_mask.size() != expected_shape: + raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}") if attention_mask is not None: # don't attend to padding symbols attn_weights = attn_weights + attention_mask - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len) + attn_weights_reshaped = attn_weights else: attn_weights_reshaped = None @@ -752,7 +735,6 @@ class ProphetNetAttention(nn.Module): attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( batch_size, self.num_attn_heads, tgt_len, src_len ) - attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len) # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped @@ -762,23 +744,12 @@ class ProphetNetAttention(nn.Module): p=self.attention_dropout, training=self.training, ) + attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim) + if attn_output.size() != expected_shape: + raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}") - attn_output = torch.bmm(attn_probs, value_states) - assert attn_output.size() == ( - batch_size * self.num_attn_heads, - tgt_len, - self.head_dim, - ), ( - f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of" - f" shape {attn_output.size()}" - ) - - attn_output = ( - attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim) - .transpose(1, 2) - .reshape(batch_size, tgt_len, hidden_size) - ) - + attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size) attn_output = self.out_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) @@ -856,7 +827,6 @@ class ProphetNetNgramSelfAttention(nn.Module): position_ids=None, ): batch_size, ngram_sequence_length, hidden_size = hidden_states.size() - assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], ( f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape" f" {hidden_states.shape}" @@ -874,8 +844,7 @@ class ProphetNetNgramSelfAttention(nn.Module): query_states = self._shape(query_states, ngram_sequence_length, batch_size) key_states = self._shape(key_states, -1, batch_size) value_states = self._shape(value_states, -1, batch_size) - - proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) query_states = query_states.view(*proj_shape) key_states = key_states.view(*proj_shape) @@ -883,10 +852,9 @@ class ProphetNetNgramSelfAttention(nn.Module): # chunk into main stream and predict stream hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) - - query_states_list = query_states.chunk(1 + self.ngram, dim=1) - key_states_list = key_states.chunk(1 + self.ngram, dim=1) - value_states_list = value_states.chunk(1 + self.ngram, dim=1) + query_states_list = query_states.chunk(1 + self.ngram, dim=2) + key_states_list = key_states.chunk(1 + self.ngram, dim=2) + value_states_list = value_states.chunk(1 + self.ngram, dim=2) main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:] main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:] @@ -895,28 +863,29 @@ class ProphetNetNgramSelfAttention(nn.Module): # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) if past_key_value is not None: - prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim) - main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1) - prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim) - main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1) + prev_main_key_states = past_key_value[0] + main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2) + prev_main_value_states = past_key_value[1] + main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2) # Update cache - past_key_value = ( - main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - ) + past_key_value = (main_key_states, main_value_states) # get seq_length of main stream only sequence_length = ngram_sequence_length // (1 + self.ngram) # MAIN-STREAM # main attn weights - main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2)) + # [batch_size, number_heads, sequence_length, head_dimesion] + # x [batch_size, number_heads, head_dimesion, sequence_length] + # -> [batch_size, number_heads, sequence_length, sequence_length] + main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3)) # retrieve relative position embeddings for each layer -> see paper for more details main_relative_pos_embeddings = self.get_main_relative_pos_embeddings( main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets ) + main_attn_weights = main_attn_weights + main_relative_pos_embeddings if attention_mask is not None: @@ -936,55 +905,53 @@ class ProphetNetNgramSelfAttention(nn.Module): main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view( batch_size, self.num_attn_heads, -1, sequence_length ) - main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length) main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) # project to attn_output - main_attn_output = torch.bmm(main_attn_probs, main_value_states) - + # [batch_size, number_heads, sequence_length, sequence_length] + # x [batch_size, number_heads, sequence_length, head_dimesion] + # -> [batch_size, number_heads, sequence_length, head_dimesion] + main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states) # reshape so that num_heads dim is merged into last `head_dim` axis - main_attn_output = ( - main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim) - .transpose(1, 2) - .reshape(batch_size, 1, sequence_length, hidden_size) - ) + main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size) main_attn_output = self.out_proj(main_attn_output) # PREDICT-STREAM - # [ngram, B*head, T, c] - predict_query_states = torch.cat(predict_query_states_list, 0).view( - self.ngram, -1, sequence_length, self.head_dim - ) - # [ngram, B*head, 2*T, c] - predict_key_states = torch.cat( - [torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0 + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_query_states = torch.stack(predict_query_states_list, 1).view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim ) - # [ngram, T, B, C] - predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view( - self.ngram, sequence_length, batch_size, hidden_size - ) + # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1) - # [ngram, B*head, 2*T, c] + # [batch_size, sequence_length, ngram, hidden_size] + predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2) + + # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion] predict_value_states = torch.cat( - [torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0 + [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2 ) - # [ngram, B*head, T, 2*T] - predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states)) - # [ngram, B*head, T, S] + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states)) + # retrieve relative position embeddings for each layer -> see paper for more details + # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings] predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings( predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets ) - # [ngram, B*head, T, 2*T] + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings if extended_predict_attention_mask is not None: - predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to( - predict_attn_weights.dtype - ) + # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4) + extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype) + predict_attn_weights = predict_attn_weights + extended_predict_attention_mask predict_attn_probs = softmax( predict_attn_weights, @@ -997,37 +964,30 @@ class ProphetNetNgramSelfAttention(nn.Module): f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" f" {layer_head_mask.size()}" ) - predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view( - self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length - ) - predict_attn_probs = predict_attn_probs.view( - self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length - ) + predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs predict_attn_probs = nn.functional.dropout( predict_attn_probs, p=self.attention_dropout, training=self.training ) # project to attention output - # [ngram, B*head, T, c] - predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_attn_output = torch.einsum( + "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2)) + ) # reshape so that num_heads dim is merged into last `head_dim` axis - # [ngram, B, T, C] - predict_attn_output = ( - predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim) - .permute(1, 0, 3, 2, 4) - .reshape(batch_size, self.ngram, sequence_length, hidden_size) - ) + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size] + predict_attn_output = predict_attn_output.transpose(2, 3) + predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size) predict_attn_output = self.out_proj(predict_attn_output) # concat to single attn output - # [B, 1+ngram*T, C] + # [batch_size, (1+ngram)*sequence_length, hidden_size] attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size) # reshape into better form for `config.output_attentions` main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1) - predict_attn_probs = predict_attn_probs.view( - self.ngram, batch_size, self.num_attn_heads, sequence_length, -1 - ).transpose(0, 1) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) @@ -1036,8 +996,11 @@ class ProphetNetNgramSelfAttention(nn.Module): def get_main_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, main_relative_position_buckets ): - # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1] - + # input hidden_states [batch_size, sequence_length, hidden_size] + # input attn_weights [batch_size, num_heads, sequence_length, sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape + attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len) if main_relative_position_buckets is None: batch_size, sequence_length = hidden_states.shape[:2] relative_positions = ( @@ -1047,39 +1010,42 @@ class ProphetNetNgramSelfAttention(nn.Module): .repeat(batch_size, sequence_length, 1) .to(position_ids.device) ) - relative_positions = relative_positions - position_ids.unsqueeze(0).repeat( - batch_size, sequence_length, 1 - ) # [B, T, s] + # [batch_size, sequence_length, sequence_length+1] + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) main_relative_position_buckets = compute_relative_buckets( self.num_buckets, self.relative_max_distance, relative_positions, False ) - rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head] + # [batch_size, sequence_length, num_buckets * num_heads] + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) rel_pos_embeddings = rel_pos_embeddings.view( rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads) - ).permute( - 0, 3, 1, 2 - ) # [B,T,Buckets,head] - rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,)) # [B*head,T,Buckets] + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2) + # [batch_size, num_heads, sequence_length, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,)) - main_relative_position_buckets = ( - main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) - .view(-1, main_relative_position_buckets.shape[-1]) - .long() - ) # [B*head*T, T] - rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) # [B*head*T,Buckets] - - main_relative_pos_embeddings = torch.gather( - rel_pos_embeddings, dim=1, index=main_relative_position_buckets - ).view(attn_weights.shape[:2] + (-1,)) + main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) + # [batch_size * num_heads * sequence_length, sequence_length] + main_relative_position_buckets = main_relative_position_buckets.view( + -1, main_relative_position_buckets.shape[-1] + ) + main_relative_position_buckets = main_relative_position_buckets.long() + # [batch_size * num_heads * sequence_length, sequence_length] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets) + main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1) return main_relative_pos_embeddings def get_predict_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets ): - # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None - sequence_length, batch_size = hidden_states.shape[1:3] + # input hidden_states [batch_size, sequence_length, ngram, hidden_size] + # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None + batch_size, sequence_length = hidden_states.shape[0:2] if predict_relative_position_buckets is None: key_sequence_length = attn_weights.shape[-1] @@ -1099,28 +1065,35 @@ class ProphetNetNgramSelfAttention(nn.Module): self.num_buckets, self.relative_max_distance, relative_positions, False ) - hidden_states = hidden_states.transpose(1, 2) # [ngram, B, T, C] - rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view( + # [batch_size, ngram, sequence_length, hidden_size] + hidden_states = hidden_states.transpose(1, 2) + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + + # [batch_size, ngram, sequence_length, num_buckets, num_heads] + rel_pos_embeddings = rel_pos_embeddings.view( hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads) - ) # [ngram, B, T, bucket, head] - rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape( - self.ngram * batch_size * self.num_attn_heads, sequence_length, -1 - ) # [ngram*B*head, T, bucket] - - predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat( + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3) + # [batch_size * ngram * sequence_length * num_heads, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets) + # [ngram, batch_size, num_heads * sequence_length, -1] + predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0) + predict_relative_position_buckets = predict_relative_position_buckets.repeat( self.ngram, 1, self.num_attn_heads, 1 - ) # [ngram, B, head*T, S] - - rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + ) + # [ngram * batch_size * num_heads * sequence_length, -1] predict_relative_position_buckets = predict_relative_position_buckets.view( -1, predict_relative_position_buckets.size(-1) - ).long() # [ngram*B*head*T, S] + ).long() predict_relative_pos_embeddings = torch.gather( rel_pos_embeddings, dim=1, index=predict_relative_position_buckets - ).view( - self.ngram, batch_size * self.num_attn_heads, sequence_length, -1 - ) # [ngram, B*head, T, S] + ) + + # [batch_size, gram, num_heads, sequence_length, -1] + predict_relative_pos_embeddings = predict_relative_pos_embeddings.view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, -1 + ) return predict_relative_pos_embeddings @@ -1331,7 +1304,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): # prepare attention mask if attention_mask is not None: extended_attention_mask = ( - 1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1) + 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1) ) * torch.finfo(self.dtype).min extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype) else: @@ -1549,7 +1522,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): # prepare encoder attention mask if encoder_attention_mask is not None: extended_encoder_attention_mask = ( - 1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1) + 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1) ) * torch.finfo(self.dtype).min extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype) else: @@ -1717,17 +1690,18 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): device=hidden_states.device, ) causal_mask = torch.triu(causal_mask, 1) - extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand( - (batch_size,) + causal_mask.shape + + extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape ) # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min extended_attention_mask = extended_causal_mask + extended_attention_mask else: extended_attention_mask = extended_causal_mask - return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype) + return extended_attention_mask.to(hidden_states.dtype) def prepare_predict_attention_mask(self, hidden_states, attention_mask): batch_size, seq_length = hidden_states.shape[:2] @@ -1745,14 +1719,16 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ], dim=-1, ) - extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand( - predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:] + extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape ) # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min - extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length)) + extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.expand( + (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length) + ) # predicted stream attention_mask should always be 0 extended_attention_mask = torch.cat( [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1 @@ -1760,9 +1736,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask else: extended_predict_attention_mask = extended_predict_causal_mask - return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to( - hidden_states.dtype - ) + return extended_predict_attention_mask.to(hidden_states.dtype) @add_start_docstrings( diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index e92f5e3906..0e2567f99c 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -716,44 +716,27 @@ class XLMProphetNetAttention(nn.Module): past_key_value = (key_states, value_states) # project states into the correct shape - proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, batch_size).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - assert attn_weights.size() == ( - batch_size * self.num_attn_heads, - tgt_len, - src_len, - ), ( - f"`attn_weights` should be of size {batch_size * self.num_attn_heads, tgt_len, src_len}, but is of size" - f" {attn_weights.shape}" - ) + src_len = key_states.size(2) + attn_weights = torch.einsum("bsij,bsjk->bsik", query_states, key_states.transpose(2, 3)) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, src_len) + if attn_weights.size() != expected_shape: + raise ValueError(f"Attention weights should have size {expected_shape}, but is {attn_weights.size()}") # This is part of a workaround to get around fork/join parallelism not supporting Optional types. if attention_mask is not None and attention_mask.dim() == 0: attention_mask = None - assert attention_mask is None or attention_mask.size() == ( - self.num_attn_heads * batch_size, - 1, - src_len, - ), ( - "`attention_mask` should be `None` or of shape attention_mask.size() ==" - f" {batch_size * self.num_attn_heads, 1, src_len}, but is {attention_mask.shape}" - ) + expected_shape = (batch_size, self.num_attn_heads, 1, src_len) + if attention_mask is not None and attention_mask.size() != expected_shape: + raise ValueError(f"Attention mask should have size {expected_shape}, but is {attention_mask.size()}") if attention_mask is not None: # don't attend to padding symbols attn_weights = attn_weights + attention_mask - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_attn_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_attn_heads, tgt_len, src_len) + attn_weights_reshaped = attn_weights else: attn_weights_reshaped = None @@ -767,7 +750,6 @@ class XLMProphetNetAttention(nn.Module): attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( batch_size, self.num_attn_heads, tgt_len, src_len ) - attn_weights = attn_weights.view(batch_size * self.num_attn_heads, tgt_len, src_len) # apply head_mask also on attn_weights_reshaped which is used for n-gram attention inside the model attn_weights_reshaped = layer_head_mask.view(1, -1, 1, 1) * attn_weights_reshaped @@ -777,23 +759,12 @@ class XLMProphetNetAttention(nn.Module): p=self.attention_dropout, training=self.training, ) + attn_output = torch.einsum("bsij,bsjk->bsik", attn_probs, value_states) + expected_shape = (batch_size, self.num_attn_heads, tgt_len, self.head_dim) + if attn_output.size() != expected_shape: + raise ValueError(f"`attn_output` should have shape {expected_shape}, but is of shape {attn_output.size()}") - attn_output = torch.bmm(attn_probs, value_states) - assert attn_output.size() == ( - batch_size * self.num_attn_heads, - tgt_len, - self.head_dim, - ), ( - f"`attn_output` should be of shape {batch_size * self.num_attn_heads, tgt_len, self.head_dim}, but is of" - f" shape {attn_output.size()}" - ) - - attn_output = ( - attn_output.view(batch_size, self.num_attn_heads, tgt_len, self.head_dim) - .transpose(1, 2) - .reshape(batch_size, tgt_len, hidden_size) - ) - + attn_output = attn_output.transpose(1, 2).reshape(batch_size, tgt_len, hidden_size) attn_output = self.out_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) @@ -873,7 +844,6 @@ class XLMProphetNetNgramSelfAttention(nn.Module): position_ids=None, ): batch_size, ngram_sequence_length, hidden_size = hidden_states.size() - assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], ( f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape" f" {hidden_states.shape}" @@ -891,8 +861,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module): query_states = self._shape(query_states, ngram_sequence_length, batch_size) key_states = self._shape(key_states, -1, batch_size) value_states = self._shape(value_states, -1, batch_size) - - proj_shape = (batch_size * self.num_attn_heads, -1, self.head_dim) + proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) query_states = query_states.view(*proj_shape) key_states = key_states.view(*proj_shape) @@ -900,10 +869,9 @@ class XLMProphetNetNgramSelfAttention(nn.Module): # chunk into main stream and predict stream hidden_states_list = hidden_states.chunk(1 + self.ngram, dim=1) - - query_states_list = query_states.chunk(1 + self.ngram, dim=1) - key_states_list = key_states.chunk(1 + self.ngram, dim=1) - value_states_list = value_states.chunk(1 + self.ngram, dim=1) + query_states_list = query_states.chunk(1 + self.ngram, dim=2) + key_states_list = key_states.chunk(1 + self.ngram, dim=2) + value_states_list = value_states.chunk(1 + self.ngram, dim=2) main_hidden_states, hidden_states_predict_list = hidden_states_list[0], hidden_states_list[1:] main_query_states, predict_query_states_list = query_states_list[0], query_states_list[1:] @@ -912,28 +880,29 @@ class XLMProphetNetNgramSelfAttention(nn.Module): # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) if past_key_value is not None: - prev_main_key_states = past_key_value[0].view(batch_size * self.num_attn_heads, -1, self.head_dim) - main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=1) - prev_main_value_states = past_key_value[1].view(batch_size * self.num_attn_heads, -1, self.head_dim) - main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=1) + prev_main_key_states = past_key_value[0] + main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2) + prev_main_value_states = past_key_value[1] + main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2) # Update cache - past_key_value = ( - main_key_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - main_value_states.view(batch_size, self.num_attn_heads, -1, self.head_dim), - ) + past_key_value = (main_key_states, main_value_states) # get seq_length of main stream only sequence_length = ngram_sequence_length // (1 + self.ngram) # MAIN-STREAM # main attn weights - main_attn_weights = torch.bmm(main_query_states, main_key_states.transpose(1, 2)) + # [batch_size, number_heads, sequence_length, head_dimesion] + # x [batch_size, number_heads, head_dimesion, sequence_length] + # -> [batch_size, number_heads, sequence_length, sequence_length] + main_attn_weights = torch.einsum("bntc,bncs->bnts", main_query_states, main_key_states.transpose(2, 3)) # retrieve relative position embeddings for each layer -> see paper for more details main_relative_pos_embeddings = self.get_main_relative_pos_embeddings( main_hidden_states, main_attn_weights, position_ids, main_relative_position_buckets ) + main_attn_weights = main_attn_weights + main_relative_pos_embeddings if attention_mask is not None: @@ -953,55 +922,53 @@ class XLMProphetNetNgramSelfAttention(nn.Module): main_attn_probs = layer_head_mask.view(1, -1, 1, 1) * main_attn_probs.view( batch_size, self.num_attn_heads, -1, sequence_length ) - main_attn_probs = main_attn_probs.view(batch_size * self.num_attn_heads, -1, sequence_length) main_attn_probs = nn.functional.dropout(main_attn_probs, p=self.attention_dropout, training=self.training) # project to attn_output - main_attn_output = torch.bmm(main_attn_probs, main_value_states) - + # [batch_size, number_heads, sequence_length, sequence_length] + # x [batch_size, number_heads, sequence_length, head_dimesion] + # -> [batch_size, number_heads, sequence_length, head_dimesion] + main_attn_output = torch.einsum("bntc,bncs->bnts", main_attn_probs, main_value_states) # reshape so that num_heads dim is merged into last `head_dim` axis - main_attn_output = ( - main_attn_output.view(batch_size, self.num_attn_heads, sequence_length, self.head_dim) - .transpose(1, 2) - .reshape(batch_size, 1, sequence_length, hidden_size) - ) + main_attn_output = main_attn_output.transpose(1, 2).reshape(batch_size, 1, sequence_length, hidden_size) main_attn_output = self.out_proj(main_attn_output) # PREDICT-STREAM - # [ngram, B*head, T, c] - predict_query_states = torch.cat(predict_query_states_list, 0).view( - self.ngram, -1, sequence_length, self.head_dim - ) - # [ngram, B*head, 2*T, c] - predict_key_states = torch.cat( - [torch.cat([main_key_states, key], 1).unsqueeze(0) for key in predict_key_states_list], 0 + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_query_states = torch.stack(predict_query_states_list, 1).view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, self.head_dim ) - # [ngram, T, B, C] - predict_hidden_states = torch.cat(hidden_states_predict_list, 0).view( - self.ngram, sequence_length, batch_size, hidden_size - ) + # [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + predict_key_states = torch.stack([torch.cat([main_key_states, key], 2) for key in predict_key_states_list], 1) - # [ngram, B*head, 2*T, c] + # [batch_size, sequence_length, ngram, hidden_size] + predict_hidden_states = torch.stack(hidden_states_predict_list, dim=2) + + # [batch_size, number_heads, ngram, 2*sequence_length, head_dimesion] predict_value_states = torch.cat( - [torch.cat([main_value_states, v_p], 1).unsqueeze(0) for v_p in predict_value_states_list], 0 + [torch.cat([main_value_states, v_p], 2).unsqueeze(2) for v_p in predict_value_states_list], 2 ) - # [ngram, B*head, T, 2*T] - predict_attn_weights = torch.einsum("nbtc,nbsc->nbts", (predict_query_states, predict_key_states)) - # [ngram, B*head, T, S] + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + predict_attn_weights = torch.einsum("bnhtc,bnhsc->bnhts", (predict_query_states, predict_key_states)) + # retrieve relative position embeddings for each layer -> see paper for more details + # [batch_size, ngram, number_heads, sequence_length, predict_relative_pos_embeddings] predict_relative_pos_embeddings = self.get_predict_relative_pos_embeddings( predict_hidden_states, predict_attn_weights, position_ids, predict_relative_position_buckets ) - # [ngram, B*head, T, 2*T] + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] predict_attn_weights = predict_attn_weights + predict_relative_pos_embeddings if extended_predict_attention_mask is not None: - predict_attn_weights = predict_attn_weights + extended_predict_attention_mask.to( - predict_attn_weights.dtype - ) + # Permuting Predict attention mask to [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + extended_predict_attention_mask = extended_predict_attention_mask.permute(0, 2, 1, 3, 4) + extended_predict_attention_mask = extended_predict_attention_mask.to(predict_attn_weights.dtype) + predict_attn_weights = predict_attn_weights + extended_predict_attention_mask predict_attn_probs = softmax( predict_attn_weights, @@ -1014,37 +981,30 @@ class XLMProphetNetNgramSelfAttention(nn.Module): f"Head mask for a single layer should be of size {(self.num_attn_heads,)}, but is" f" {layer_head_mask.size()}" ) - predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs.view( - self.ngram, batch_size, self.num_attn_heads, sequence_length, 2 * sequence_length - ) - predict_attn_probs = predict_attn_probs.view( - self.ngram, batch_size * self.num_attn_heads, sequence_length, 2 * sequence_length - ) + predict_attn_probs = layer_head_mask.view(1, 1, -1, 1, 1) * predict_attn_probs predict_attn_probs = nn.functional.dropout( predict_attn_probs, p=self.attention_dropout, training=self.training ) # project to attention output - # [ngram, B*head, T, c] - predict_attn_output = torch.einsum("nbts,nbsc->nbtc", (predict_attn_probs, predict_value_states)) + # [batch_size, ngram, number_heads, sequence_length, 2*sequence_length] + # x [batch_size, ngram, number_heads, 2*sequence_length, head_dimesion] + # -> [batch_size, ngram, number_heads, sequence_length, head_dimesion] + predict_attn_output = torch.einsum( + "bnhts,bnhsc->bnhtc", (predict_attn_probs, predict_value_states.transpose(1, 2)) + ) # reshape so that num_heads dim is merged into last `head_dim` axis - # [ngram, B, T, C] - predict_attn_output = ( - predict_attn_output.view(self.ngram, batch_size, self.num_attn_heads, sequence_length, self.head_dim) - .permute(1, 0, 3, 2, 4) - .reshape(batch_size, self.ngram, sequence_length, hidden_size) - ) + # [batch_size, ngram, number_heads, sequence_length, head_dimesion] -> [batch_size, ngram, sequence_length, hidden_size] + predict_attn_output = predict_attn_output.transpose(2, 3) + predict_attn_output = predict_attn_output.reshape(batch_size, self.ngram, sequence_length, hidden_size) predict_attn_output = self.out_proj(predict_attn_output) # concat to single attn output - # [B, 1+ngram*T, C] + # [batch_size, (1+ngram)*sequence_length, hidden_size] attn_output = torch.cat([main_attn_output, predict_attn_output], 1).view(batch_size, -1, hidden_size) # reshape into better form for `config.output_attentions` main_attn_probs = main_attn_probs.view(batch_size, self.num_attn_heads, sequence_length, -1) - predict_attn_probs = predict_attn_probs.view( - self.ngram, batch_size, self.num_attn_heads, sequence_length, -1 - ).transpose(0, 1) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) @@ -1053,8 +1013,11 @@ class XLMProphetNetNgramSelfAttention(nn.Module): def get_main_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, main_relative_position_buckets ): - # input hidden_states [B,T,C], input attn_weights [T*head,T,S], input position_ids [B,T] or [1,1] - + # input hidden_states [batch_size, sequence_length, hidden_size] + # input attn_weights [batch_size, num_heads, sequence_length, sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + batch_size, num_attn_heads, tgt_len, src_len = attn_weights.shape + attn_weights = attn_weights.view(batch_size, num_attn_heads, tgt_len, src_len) if main_relative_position_buckets is None: batch_size, sequence_length = hidden_states.shape[:2] relative_positions = ( @@ -1064,39 +1027,42 @@ class XLMProphetNetNgramSelfAttention(nn.Module): .repeat(batch_size, sequence_length, 1) .to(position_ids.device) ) - relative_positions = relative_positions - position_ids.unsqueeze(0).repeat( - batch_size, sequence_length, 1 - ) # [B, T, s] + # [batch_size, sequence_length, sequence_length+1] + relative_positions = relative_positions - position_ids.unsqueeze(0).repeat(batch_size, sequence_length, 1) main_relative_position_buckets = compute_relative_buckets( self.num_buckets, self.relative_max_distance, relative_positions, False ) - rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) # [B,T,Buckets*head] + # [batch_size, sequence_length, num_buckets * num_heads] + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) rel_pos_embeddings = rel_pos_embeddings.view( rel_pos_embeddings.shape[:2] + (self.num_buckets, self.num_attn_heads) - ).permute( - 0, 3, 1, 2 - ) # [B,T,Buckets,head] - rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:2] + (-1,)) # [B*head,T,Buckets] + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 3, 1, 2) + # [batch_size, num_heads, sequence_length, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(attn_weights.shape[:3] + (-1,)) - main_relative_position_buckets = ( - main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) - .view(-1, main_relative_position_buckets.shape[-1]) - .long() - ) # [B*head*T, T] - rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) # [B*head*T,Buckets] - - main_relative_pos_embeddings = torch.gather( - rel_pos_embeddings, dim=1, index=main_relative_position_buckets - ).view(attn_weights.shape[:2] + (-1,)) + main_relative_position_buckets = main_relative_position_buckets.repeat(1, self.num_attn_heads, 1) + # [batch_size * num_heads * sequence_length, sequence_length] + main_relative_position_buckets = main_relative_position_buckets.view( + -1, main_relative_position_buckets.shape[-1] + ) + main_relative_position_buckets = main_relative_position_buckets.long() + # [batch_size * num_heads * sequence_length, sequence_length] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + main_relative_pos_embeddings = torch.gather(rel_pos_embeddings, dim=1, index=main_relative_position_buckets) + main_relative_pos_embeddings = main_relative_pos_embeddings.view(batch_size, num_attn_heads, tgt_len, -1) return main_relative_pos_embeddings def get_predict_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, predict_relative_position_buckets ): - # input hidden_states [ngram, T,B,C], input attn_weights [ngram, B*head,T,S], input position_ids [B,T] or [1,1], input predict_relative_position_buckets [B,T, 2*T] or None - sequence_length, batch_size = hidden_states.shape[1:3] + # input hidden_states [batch_size, sequence_length, ngram, hidden_size] + # input attn_weights [batch_size, ngram, num_heads, sequence_length, 2*sequence_length] + # input position_ids [batch_size, sequence_length] or [1,1] + # input predict_relative_position_buckets [batch_size, sequence_length, 2*sequence_length] or None + batch_size, sequence_length = hidden_states.shape[0:2] if predict_relative_position_buckets is None: key_sequence_length = attn_weights.shape[-1] @@ -1116,28 +1082,35 @@ class XLMProphetNetNgramSelfAttention(nn.Module): self.num_buckets, self.relative_max_distance, relative_positions, False ) - hidden_states = hidden_states.transpose(1, 2) # [ngram, B, T, C] - rel_pos_embeddings = self.relative_pos_embeddings(hidden_states).view( + # [batch_size, ngram, sequence_length, hidden_size] + hidden_states = hidden_states.transpose(1, 2) + rel_pos_embeddings = self.relative_pos_embeddings(hidden_states) + + # [batch_size, ngram, sequence_length, num_buckets, num_heads] + rel_pos_embeddings = rel_pos_embeddings.view( hidden_states.shape[:-1] + (self.num_buckets, self.num_attn_heads) - ) # [ngram, B, T, bucket, head] - rel_pos_embeddings = rel_pos_embeddings.permute(0, 1, 4, 2, 3).reshape( - self.ngram * batch_size * self.num_attn_heads, sequence_length, -1 - ) # [ngram*B*head, T, bucket] - - predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0).repeat( + ) + rel_pos_embeddings = rel_pos_embeddings.permute(0, 2, 1, 4, 3) + # [batch_size * ngram * sequence_length * num_heads, num_buckets] + rel_pos_embeddings = rel_pos_embeddings.reshape(-1, self.num_buckets) + # [ngram, batch_size, num_heads * sequence_length, -1] + predict_relative_position_buckets = predict_relative_position_buckets.unsqueeze(0) + predict_relative_position_buckets = predict_relative_position_buckets.repeat( self.ngram, 1, self.num_attn_heads, 1 - ) # [ngram, B, head*T, S] - - rel_pos_embeddings = rel_pos_embeddings.reshape(-1, rel_pos_embeddings.size(-1)) + ) + # [ngram * batch_size * num_heads * sequence_length, -1] predict_relative_position_buckets = predict_relative_position_buckets.view( -1, predict_relative_position_buckets.size(-1) - ).long() # [ngram*B*head*T, S] + ).long() predict_relative_pos_embeddings = torch.gather( rel_pos_embeddings, dim=1, index=predict_relative_position_buckets - ).view( - self.ngram, batch_size * self.num_attn_heads, sequence_length, -1 - ) # [ngram, B*head, T, S] + ) + + # [batch_size, gram, num_heads, sequence_length, -1] + predict_relative_pos_embeddings = predict_relative_pos_embeddings.view( + batch_size, self.ngram, self.num_attn_heads, sequence_length, -1 + ) return predict_relative_pos_embeddings @@ -1351,7 +1324,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel): # prepare attention mask if attention_mask is not None: extended_attention_mask = ( - 1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1) + 1.0 - attention_mask[:, None, None, :].repeat(1, self.config.num_encoder_attention_heads, 1, 1) ) * torch.finfo(self.dtype).min extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype) else: @@ -1572,7 +1545,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): # prepare encoder attention mask if encoder_attention_mask is not None: extended_encoder_attention_mask = ( - 1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1) + 1.0 - encoder_attention_mask[:, None, None, :].repeat(1, self.config.num_decoder_attention_heads, 1, 1) ) * torch.finfo(self.dtype).min extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype) else: @@ -1740,17 +1713,18 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): device=hidden_states.device, ) causal_mask = torch.triu(causal_mask, 1) - extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand( - (batch_size,) + causal_mask.shape + + extended_causal_mask = causal_mask[:seq_length, :seq_length][None, None, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + causal_mask.shape ) # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * torch.finfo(self.dtype).min extended_attention_mask = extended_causal_mask + extended_attention_mask else: extended_attention_mask = extended_causal_mask - return extended_attention_mask.repeat(self.config.num_decoder_attention_heads, 1, 1).to(hidden_states.dtype) + return extended_attention_mask.to(hidden_states.dtype) def prepare_predict_attention_mask(self, hidden_states, attention_mask): batch_size, seq_length = hidden_states.shape[:2] @@ -1768,14 +1742,16 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): ], dim=-1, ) - extended_predict_causal_mask = predict_causal_mask[:, None, :, :].expand( - predict_causal_mask.shape[:1] + (batch_size,) + predict_causal_mask.shape[1:] + extended_predict_causal_mask = predict_causal_mask[None, None, :, :, :].expand( + (batch_size, self.config.num_decoder_attention_heads) + predict_causal_mask.shape ) # add usual attention mask if attention_mask is not None: - extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min - extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length)) + extended_attention_mask = (1.0 - attention_mask[:, None, None, None, :]) * torch.finfo(self.dtype).min + extended_attention_mask = extended_attention_mask.expand( + (batch_size, self.config.num_decoder_attention_heads, self.ngram, seq_length, seq_length) + ) # predicted stream attention_mask should always be 0 extended_attention_mask = torch.cat( [extended_attention_mask, torch.zeros_like(extended_attention_mask)], dim=-1 @@ -1783,9 +1759,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): extended_predict_attention_mask = extended_predict_causal_mask + extended_attention_mask else: extended_predict_attention_mask = extended_predict_causal_mask - return extended_predict_attention_mask.repeat(1, self.config.num_decoder_attention_heads, 1, 1).to( - hidden_states.dtype - ) + return extended_predict_attention_mask.to(hidden_states.dtype) @add_start_docstrings( diff --git a/tests/models/prophetnet/test_modeling_prophetnet.py b/tests/models/prophetnet/test_modeling_prophetnet.py index 36d5da6836..2579be5396 100644 --- a/tests/models/prophetnet/test_modeling_prophetnet.py +++ b/tests/models/prophetnet/test_modeling_prophetnet.py @@ -1206,7 +1206,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase): expected_shape = torch.Size((1, 12, 30522)) self.assertEqual(output_predited_logits.shape, expected_shape) expected_slice = torch.tensor( - [[[-7.6213, -7.9008, -7.9979], [-7.6834, -7.8467, -8.2187], [-7.5326, -7.4762, -8.1914]]] + [[[-7.7729, -8.0343, -8.26001], [-7.74213, -7.8629, -8.6000], [-7.7328, -7.8269, -8.5264]]] ).to(torch_device) # self.assertTrue(torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4)) assert torch.allclose(output_predited_logits[:, :3, :3], expected_slice, atol=1e-4) @@ -1306,7 +1306,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase): EXPECTED_QUESTIONS = [ "along with paul allen, who founded microsoft?", "what year was microsoft founded?", - "on what date was microsoft founded?", + "when was microsoft founded?", ] self.assertListEqual(