From 151e4ab4e786b9b4b702205b5077ea2dfe67b4dd Mon Sep 17 00:00:00 2001
From: LysandreJik
Date: Tue, 5 Nov 2019 16:26:51 +0000
Subject: [PATCH] Fix CTRL past

---
 transformers/modeling_ctrl.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/transformers/modeling_ctrl.py b/transformers/modeling_ctrl.py
index 1873040a8e..589a065a11 100644
--- a/transformers/modeling_ctrl.py
+++ b/transformers/modeling_ctrl.py
@@ -63,7 +63,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
     scaled_attention_logits = matmul_qk / np.sqrt(dk)
 
     if mask is not None:
-        scaled_attention_logits += (mask * -1e4)
+        nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
+        scaled_attention_logits += (mask[ns-nd:ns, :ns] * -1e4)
 
     if attention_mask is not None:
         # Apply the attention mask
@@ -357,7 +358,7 @@ class CTRLModel(CTRLPreTrainedModel):
             inputs_embeds = self.w(input_ids)
         # inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_ids.shape[-1]
-        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
+        mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
 
         inputs_embeds *= np.sqrt(self.d_model_size)
 