Output the hidden states of all layers in GPT/GPT-2

This commit is contained in:
thomwolf
2019-06-17 14:34:12 +02:00
parent f12007e421
commit 965f172de6
4 changed files with 43 additions and 12 deletions

View File

@@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
all_attentions = []
all_hidden_states = [hidden_states.view(*output_shape)]
for block in self.h:
outputs = block(hidden_states, head_mask)
if self.output_attentions:
@@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_attentions.append(attentions)
else:
hidden_states = outputs
output_shape = input_shape + (hidden_states.size(-1),)
all_hidden_states.append(hidden_states.view(*output_shape))
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape)
return hidden_states.view(*output_shape)
return all_attentions, all_hidden_states
return all_hidden_states
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
@@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
hidden_states = hidden_states[-1]
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
@@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
hidden_states = hidden_states[-1]
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []