update model to use past
This commit is contained in:
@@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype):
|
|||||||
sines = torch.sin(angle_rads[:, 0::2])
|
sines = torch.sin(angle_rads[:, 0::2])
|
||||||
cosines = torch.cos(angle_rads[:, 1::2])
|
cosines = torch.cos(angle_rads[:, 1::2])
|
||||||
|
|
||||||
pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
|
pos_encoding = torch.cat([sines, cosines], dim=-1)
|
||||||
return pos_encoding
|
return pos_encoding
|
||||||
|
|
||||||
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
|
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
|
||||||
@@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module):
|
|||||||
k = self.split_into_heads(k, batch_size)
|
k = self.split_into_heads(k, batch_size)
|
||||||
v = self.split_into_heads(v, batch_size)
|
v = self.split_into_heads(v, batch_size)
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
past_key, past_value = layer_past[0], layer_past[1]
|
||||||
k = torch.cat((past_key, k), dim=-1)
|
k = torch.cat((past_key, k), dim=-1)
|
||||||
v = torch.cat((past_value, v), dim=-2)
|
v = torch.cat((past_value, v), dim=-2)
|
||||||
present = torch.stack((k.transpose(-2, -1), v)) # transpose to have same shapes for stacking
|
present = torch.stack((k, v))
|
||||||
|
|
||||||
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions)
|
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
|
||||||
scaled_attention = output[0].permute([0, 2, 1, 3])
|
scaled_attention = output[0].permute([0, 2, 1, 3])
|
||||||
attn = output[1]
|
attn = output[1]
|
||||||
original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
|
original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
|
||||||
output = self.dense(original_size_attention)
|
output = self.dense(original_size_attention)
|
||||||
|
|
||||||
return output, attn
|
outputs = (output, present)
|
||||||
|
if self.output_attentions:
|
||||||
|
outputs = outputs + (attn,)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
|
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
|
||||||
normed = self.layernorm1(x)
|
normed = self.layernorm1(x)
|
||||||
attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
|
attn_outputs = self.multi_head_attention(normed, normed, normed, mask,
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask)
|
head_mask=head_mask)
|
||||||
|
attn_output = attn_outputs[0]
|
||||||
attn_output = self.dropout1(attn_output)
|
attn_output = self.dropout1(attn_output)
|
||||||
out1 = x + attn_output
|
out1 = x + attn_output
|
||||||
|
|
||||||
@@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module):
|
|||||||
ffn_output = self.dropout2(ffn_output)
|
ffn_output = self.dropout2(ffn_output)
|
||||||
out2 = out1 + ffn_output
|
out2 = out1 + ffn_output
|
||||||
|
|
||||||
return out2, attn
|
outputs = (out2,) + attn_outputs[1:]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class CTRLPreTrainedModel(PreTrainedModel):
|
class CTRLPreTrainedModel(PreTrainedModel):
|
||||||
@@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.n_layer
|
head_mask = [None] * self.config.n_layer
|
||||||
|
|
||||||
embedded = self.w(input_ids)
|
x = self.w(input_ids)
|
||||||
x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
# x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
|
||||||
seq_len = input_ids.shape[1]
|
seq_len = input_ids.shape[1]
|
||||||
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
|
mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)
|
||||||
|
|
||||||
x *= np.sqrt(self.d_model_size)
|
x *= np.sqrt(self.d_model_size)
|
||||||
|
|
||||||
x += self.pos_encoding[:, position_ids, :].to(x.device)
|
pos_x = self.pos_encoding[position_ids, :].to(x.device)
|
||||||
|
x += pos_x
|
||||||
|
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
|||||||
@@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
|
|||||||
|
|
||||||
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
|
||||||
model(input_ids, token_type_ids=token_type_ids)
|
model(input_ids, token_type_ids=token_type_ids)
|
||||||
sequence_output, _ = model(input_ids)
|
sequence_output, presents = model(input_ids)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"sequence_output": sequence_output,
|
"sequence_output": sequence_output,
|
||||||
|
"presents": presents,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["sequence_output"].size()),
|
list(result["sequence_output"].size()),
|
||||||
[self.batch_size, self.seq_length, self.hidden_size])
|
[self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
||||||
|
|
||||||
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||||
model = CTRLLMHeadModel(config)
|
model = CTRLLMHeadModel(config)
|
||||||
|
|||||||
Reference in New Issue
Block a user