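Fix CTRL past: build the causal mask with size seq_len + past_length and slice it to the current query/key lengths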
This commit is contained in:
@@ -63,7 +63,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
|
|||||||
scaled_attention_logits = matmul_qk / np.sqrt(dk)
|
scaled_attention_logits = matmul_qk / np.sqrt(dk)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scaled_attention_logits += (mask * -1e4)
|
nd, ns = scaled_attention_logits.size(-2), scaled_attention_logits.size(-1)
|
||||||
|
scaled_attention_logits += (mask[ns-nd:ns, :ns] * -1e4)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
@@ -357,7 +358,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
inputs_embeds = self.w(input_ids)
|
inputs_embeds = self.w(input_ids)
|
||||||
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
# inputs_embeds = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||||
seq_len = input_ids.shape[-1]
|
seq_len = input_ids.shape[-1]
|
||||||
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(inputs_embeds.device)
|
mask = torch.triu(torch.ones(seq_len + past_length, seq_len + past_length), 1).to(inputs_embeds.device)
|
||||||
|
|
||||||
inputs_embeds *= np.sqrt(self.d_model_size)
|
inputs_embeds *= np.sqrt(self.d_model_size)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user