output all hidden layers states in GPT/GPT-2
This commit is contained in:
@@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
hidden_states = self.drop(hidden_states)
|
||||
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
|
||||
all_attentions = []
|
||||
all_hidden_states = [hidden_states.view(*output_shape)]
|
||||
for block in self.h:
|
||||
outputs = block(hidden_states, head_mask)
|
||||
if self.output_attentions:
|
||||
@@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
all_attentions.append(attentions)
|
||||
else:
|
||||
hidden_states = outputs
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
all_hidden_states.append(hidden_states.view(*output_shape))
|
||||
|
||||
if self.output_attentions:
|
||||
return all_attentions, hidden_states.view(*output_shape)
|
||||
return hidden_states.view(*output_shape)
|
||||
return all_attentions, all_hidden_states
|
||||
return all_hidden_states
|
||||
|
||||
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
@@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||
if self.transformer.output_attentions:
|
||||
all_attentions, hidden_states = hidden_states
|
||||
hidden_states = hidden_states[-1]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
if lm_labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
@@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||
if self.transformer.output_attentions:
|
||||
all_attentions, hidden_states = hidden_states
|
||||
hidden_states = hidden_states[-1]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||
losses = []
|
||||
|
||||
Reference in New Issue
Block a user