From 3edfa1d6aaf30247c413fa15f04758b96d04762c Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 8 Oct 2019 17:11:58 +0200 Subject: [PATCH] update model to use past --- transformers/modeling_ctrl.py | 26 +++++++++++++++--------- transformers/tests/modeling_ctrl_test.py | 4 +++- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/transformers/modeling_ctrl.py b/transformers/modeling_ctrl.py index cef2a666a4..a7b67f0674 100644 --- a/transformers/modeling_ctrl.py +++ b/transformers/modeling_ctrl.py @@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype): sines = torch.sin(angle_rads[:, 0::2]) cosines = torch.cos(angle_rads[:, 1::2]) - pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0) + pos_encoding = torch.cat([sines, cosines], dim=-1) return pos_encoding def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None): @@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module): k = self.split_into_heads(k, batch_size) v = self.split_into_heads(v, batch_size) if layer_past is not None: - past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below + past_key, past_value = layer_past[0], layer_past[1] k = torch.cat((past_key, k), dim=-1) v = torch.cat((past_value, v), dim=-2) - present = torch.stack((k.transpose(-2, -1), v)) # transpose to have same shapes for stacking + present = torch.stack((k, v)) - output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions) + output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) scaled_attention = output[0].permute([0, 2, 1, 3]) attn = output[1] original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size) output = self.dense(original_size_attention) - return output, attn + outputs = (output, present) + if self.output_attentions: + outputs = outputs + (attn,) + return outputs @@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module): def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None): normed = self.layernorm1(x) - attn_output, attn = self.multi_head_attention(normed, normed, normed, mask, + attn_outputs = self.multi_head_attention(normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask) + attn_output = attn_outputs[0] attn_output = self.dropout1(attn_output) out1 = x + attn_output @@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module): ffn_output = self.dropout2(ffn_output) out2 = out1 + ffn_output - return out2, attn + outputs = (out2,) + attn_outputs[1:] + return outputs class CTRLPreTrainedModel(PreTrainedModel): @@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel): else: head_mask = [None] * self.config.n_layer - embedded = self.w(input_ids) - x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded + x = self.w(input_ids) + # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded seq_len = input_ids.shape[1] mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device) x *= np.sqrt(self.d_model_size) - x += self.pos_encoding[:, position_ids, :].to(x.device) + pos_x = self.pos_encoding[position_ids, :].to(x.device) + x += pos_x x = self.dropout(x) diff --git a/transformers/tests/modeling_ctrl_test.py b/transformers/tests/modeling_ctrl_test.py index ac7c32b113..47ff8d8d51 100644 --- a/transformers/tests/modeling_ctrl_test.py +++ b/transformers/tests/modeling_ctrl_test.py @@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester): model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask) model(input_ids, token_type_ids=token_type_ids) - sequence_output, _ = model(input_ids) + sequence_output, presents = model(input_ids) result = { "sequence_output": sequence_output, + "presents": presents, } self.parent.assertListEqual( list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertEqual(len(result["presents"]), config.n_layer) def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): model = CTRLLMHeadModel(config)