Update CTRL model to use cached past key/value states (`layer_past`)
This commit is contained in:
@@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype):
|
||||
sines = torch.sin(angle_rads[:, 0::2])
|
||||
cosines = torch.cos(angle_rads[:, 1::2])
|
||||
|
||||
pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
|
||||
pos_encoding = torch.cat([sines, cosines], dim=-1)
|
||||
return pos_encoding
|
||||
|
||||
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
|
||||
@@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module):
|
||||
k = self.split_into_heads(k, batch_size)
|
||||
v = self.split_into_heads(v, batch_size)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
k = torch.cat((past_key, k), dim=-1)
|
||||
v = torch.cat((past_value, v), dim=-2)
|
||||
present = torch.stack((k.transpose(-2, -1), v)) # transpose to have same shapes for stacking
|
||||
present = torch.stack((k, v))
|
||||
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions)
|
||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||
scaled_attention = output[0].permute([0, 2, 1, 3])
|
||||
attn = output[1]
|
||||
original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
|
||||
output = self.dense(original_size_attention)
|
||||
|
||||
return output, attn
|
||||
outputs = (output, present)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (attn,)
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module):
|
||||
|
||||
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
|
||||
normed = self.layernorm1(x)
|
||||
attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
|
||||
attn_outputs = self.multi_head_attention(normed, normed, normed, mask,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask)
|
||||
attn_output = attn_outputs[0]
|
||||
attn_output = self.dropout1(attn_output)
|
||||
out1 = x + attn_output
|
||||
|
||||
@@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module):
|
||||
ffn_output = self.dropout2(ffn_output)
|
||||
out2 = out1 + ffn_output
|
||||
|
||||
return out2, attn
|
||||
outputs = (out2,) + attn_outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class CTRLPreTrainedModel(PreTrainedModel):
|
||||
@@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
else:
|
||||
head_mask = [None] * self.config.n_layer
|
||||
|
||||
embedded = self.w(input_ids)
|
||||
x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||
x = self.w(input_ids)
|
||||
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||
seq_len = input_ids.shape[1]
|
||||
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
|
||||
|
||||
x *= np.sqrt(self.d_model_size)
|
||||
|
||||
x += self.pos_encoding[:, position_ids, :].to(x.device)
|
||||
pos_x = self.pos_encoding[position_ids, :].to(x.device)
|
||||
x += pos_x
|
||||
|
||||
x = self.dropout(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user