From 965f172de6aa52369500a4c6bc76244f69272c0f Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 17 Jun 2019 14:34:12 +0200 Subject: [PATCH] output all hidden layers states in GPT/GPT-2 --- pytorch_pretrained_bert/modeling_gpt2.py | 15 ++++++++++++--- pytorch_pretrained_bert/modeling_openai.py | 14 +++++++++++--- tests/modeling_gpt2_test.py | 13 ++++++++++--- tests/modeling_openai_test.py | 13 ++++++++++--- 4 files changed, 43 insertions(+), 12 deletions(-) diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py index e9fc1c5f98..9240ea2bd0 100644 --- a/pytorch_pretrained_bert/modeling_gpt2.py +++ b/pytorch_pretrained_bert/modeling_gpt2.py @@ -720,9 +720,13 @@ class GPT2Model(GPT2PreTrainedModel): hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = self.drop(hidden_states) + output_shape = input_shape + (hidden_states.size(-1),) + presents = [] all_attentions = [] + all_hidden_states = [] for block, layer_past in zip(self.h, past): + all_hidden_states.append(hidden_states.view(*output_shape)) outputs = block(hidden_states, layer_past, head_mask) if self.output_attentions: attentions, hidden_states, present = outputs @@ -731,10 +735,11 @@ class GPT2Model(GPT2PreTrainedModel): hidden_states, present = outputs presents.append(present) hidden_states = self.ln_f(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) + all_hidden_states.append(hidden_states.view(*output_shape)) + if self.output_attentions: - return all_attentions, hidden_states.view(*output_shape), presents - return hidden_states.view(*output_shape), presents + return all_attentions, all_hidden_states, presents + return all_hidden_states, presents class GPT2LMHeadModel(GPT2PreTrainedModel): @@ -802,6 +807,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): all_attentions, hidden_states, presents = transformer_output else: hidden_states, presents = transformer_output + hidden_states = hidden_states[-1] + lm_logits = self.lm_head(hidden_states) if lm_labels is not None: # Shift so that tokens < n predict n @@ -889,6 +896,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): all_attentions, hidden_states, presents = transformer_output else: hidden_states, presents = transformer_output + hidden_states = hidden_states[-1] + lm_logits = self.lm_head(hidden_states) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py index bbc60ffd2c..32c0978dd0 100644 --- a/pytorch_pretrained_bert/modeling_openai.py +++ b/pytorch_pretrained_bert/modeling_openai.py @@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = self.drop(hidden_states) + output_shape = input_shape + (hidden_states.size(-1),) + all_attentions = [] + all_hidden_states = [hidden_states.view(*output_shape)] for block in self.h: outputs = block(hidden_states, head_mask) if self.output_attentions: @@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): all_attentions.append(attentions) else: hidden_states = outputs - output_shape = input_shape + (hidden_states.size(-1),) + all_hidden_states.append(hidden_states.view(*output_shape)) + if self.output_attentions: - return all_attentions, hidden_states.view(*output_shape) - return hidden_states.view(*output_shape) + return all_attentions, all_hidden_states + return all_hidden_states class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): @@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask) if self.transformer.output_attentions: all_attentions, hidden_states = hidden_states + hidden_states = hidden_states[-1] + lm_logits = self.lm_head(hidden_states) if lm_labels is not None: # Shift so that tokens < n predict n @@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask) if self.transformer.output_attentions: all_attentions, hidden_states = hidden_states + hidden_states = hidden_states[-1] + lm_logits = self.lm_head(hidden_states) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) losses = [] diff --git a/tests/modeling_gpt2_test.py b/tests/modeling_gpt2_test.py index aaa88d54e8..c9a7a64b5a 100644 --- a/tests/modeling_gpt2_test.py +++ b/tests/modeling_gpt2_test.py @@ -115,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase): return outputs def check_gpt2_model_output(self, result): + self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1) self.parent.assertListEqual( - list(result["hidden_states"].size()), + list(result["hidden_states"][0].size()), [self.batch_size, self.n_choices, self.seq_length, self.n_embd]) @@ -222,7 +223,10 @@ class GPT2ModelTest(unittest.TestCase): else: output = model(input_ids, head_mask=head_mask) - output = sum(t.sum() for t in output[:-1]) + if isinstance(model, GPT2Model): + output = sum(t.sum() for t in output[0]) + elif isinstance(output, (list, tuple)): + output = sum(t.sum() for t in output[:-1]) output = output.sum() output.backward() multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs() @@ -256,7 +260,10 @@ class GPT2ModelTest(unittest.TestCase): else: output = model(input_ids) - output = sum(t.sum() for t in output[:-1]) + if isinstance(model, GPT2Model): + output = sum(t.sum() for t in output[0]) + elif isinstance(output, (list, tuple)): + output = sum(t.sum() for t in output[:-1]) output = output.sum() output.backward() multihead_outputs = transformer.get_multihead_outputs() diff --git a/tests/modeling_openai_test.py b/tests/modeling_openai_test.py index 08353cdd18..86234e57ca 100644 --- a/tests/modeling_openai_test.py +++ b/tests/modeling_openai_test.py @@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase): return outputs def check_openai_model_output(self, result): + self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1) self.parent.assertListEqual( - list(result["hidden_states"].size()), + list(result["hidden_states"][0].size()), [self.batch_size, self.n_choices, self.seq_length, self.n_embd]) @@ -195,7 +196,10 @@ class OpenAIGPTModelTest(unittest.TestCase): else: output = model(input_ids, head_mask=head_mask) - output = sum(t.sum() for t in output[:-1]) + if isinstance(model, OpenAIGPTModel): + output = sum(t.sum() for t in output[0]) + elif isinstance(output, (list, tuple)): + output = sum(t.sum() for t in output) output = output.sum() output.backward() multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs() @@ -229,7 +233,10 @@ class OpenAIGPTModelTest(unittest.TestCase): else: output = model(input_ids) - output = sum(t.sum() for t in output[:-1]) + if isinstance(model, OpenAIGPTModel): + output = sum(t.sum() for t in output[0]) + elif isinstance(output, (list, tuple)): + output = sum(t.sum() for t in output) output = output.sum() output.backward() multihead_outputs = transformer.get_multihead_outputs()