output all hidden layers states in GPT/GPT-2
This commit is contained in:
@@ -720,9 +720,13 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
|
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
presents = []
|
presents = []
|
||||||
all_attentions = []
|
all_attentions = []
|
||||||
|
all_hidden_states = []
|
||||||
for block, layer_past in zip(self.h, past):
|
for block, layer_past in zip(self.h, past):
|
||||||
|
all_hidden_states.append(hidden_states.view(*output_shape))
|
||||||
outputs = block(hidden_states, layer_past, head_mask)
|
outputs = block(hidden_states, layer_past, head_mask)
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
attentions, hidden_states, present = outputs
|
attentions, hidden_states, present = outputs
|
||||||
@@ -731,10 +735,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
hidden_states, present = outputs
|
hidden_states, present = outputs
|
||||||
presents.append(present)
|
presents.append(present)
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
all_hidden_states.append(hidden_states.view(*output_shape))
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
return all_attentions, hidden_states.view(*output_shape), presents
|
return all_attentions, all_hidden_states, presents
|
||||||
return hidden_states.view(*output_shape), presents
|
return all_hidden_states, presents
|
||||||
|
|
||||||
|
|
||||||
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||||
@@ -802,6 +807,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
all_attentions, hidden_states, presents = transformer_output
|
all_attentions, hidden_states, presents = transformer_output
|
||||||
else:
|
else:
|
||||||
hidden_states, presents = transformer_output
|
hidden_states, presents = transformer_output
|
||||||
|
hidden_states = hidden_states[-1]
|
||||||
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
@@ -889,6 +896,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
all_attentions, hidden_states, presents = transformer_output
|
all_attentions, hidden_states, presents = transformer_output
|
||||||
else:
|
else:
|
||||||
hidden_states, presents = transformer_output
|
hidden_states, presents = transformer_output
|
||||||
|
hidden_states = hidden_states[-1]
|
||||||
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
losses = []
|
losses = []
|
||||||
|
|||||||
@@ -716,7 +716,10 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||||
hidden_states = self.drop(hidden_states)
|
hidden_states = self.drop(hidden_states)
|
||||||
|
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
all_attentions = []
|
all_attentions = []
|
||||||
|
all_hidden_states = [hidden_states.view(*output_shape)]
|
||||||
for block in self.h:
|
for block in self.h:
|
||||||
outputs = block(hidden_states, head_mask)
|
outputs = block(hidden_states, head_mask)
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
@@ -724,10 +727,11 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
all_attentions.append(attentions)
|
all_attentions.append(attentions)
|
||||||
else:
|
else:
|
||||||
hidden_states = outputs
|
hidden_states = outputs
|
||||||
output_shape = input_shape + (hidden_states.size(-1),)
|
all_hidden_states.append(hidden_states.view(*output_shape))
|
||||||
|
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
return all_attentions, hidden_states.view(*output_shape)
|
return all_attentions, all_hidden_states
|
||||||
return hidden_states.view(*output_shape)
|
return all_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||||
@@ -805,6 +809,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||||
if self.transformer.output_attentions:
|
if self.transformer.output_attentions:
|
||||||
all_attentions, hidden_states = hidden_states
|
all_attentions, hidden_states = hidden_states
|
||||||
|
hidden_states = hidden_states[-1]
|
||||||
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
if lm_labels is not None:
|
if lm_labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
@@ -902,6 +908,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
|
||||||
if self.transformer.output_attentions:
|
if self.transformer.output_attentions:
|
||||||
all_attentions, hidden_states = hidden_states
|
all_attentions, hidden_states = hidden_states
|
||||||
|
hidden_states = hidden_states[-1]
|
||||||
|
|
||||||
lm_logits = self.lm_head(hidden_states)
|
lm_logits = self.lm_head(hidden_states)
|
||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
|
||||||
losses = []
|
losses = []
|
||||||
|
|||||||
@@ -115,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def check_gpt2_model_output(self, result):
|
def check_gpt2_model_output(self, result):
|
||||||
|
self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["hidden_states"].size()),
|
list(result["hidden_states"][0].size()),
|
||||||
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
|
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
|
||||||
|
|
||||||
|
|
||||||
@@ -222,7 +223,10 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
output = model(input_ids, head_mask=head_mask)
|
output = model(input_ids, head_mask=head_mask)
|
||||||
|
|
||||||
output = sum(t.sum() for t in output[:-1])
|
if isinstance(model, GPT2Model):
|
||||||
|
output = sum(t.sum() for t in output[0])
|
||||||
|
elif isinstance(output, (list, tuple)):
|
||||||
|
output = sum(t.sum() for t in output[:-1])
|
||||||
output = output.sum()
|
output = output.sum()
|
||||||
output.backward()
|
output.backward()
|
||||||
multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
|
multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
|
||||||
@@ -256,7 +260,10 @@ class GPT2ModelTest(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
output = model(input_ids)
|
output = model(input_ids)
|
||||||
|
|
||||||
output = sum(t.sum() for t in output[:-1])
|
if isinstance(model, GPT2Model):
|
||||||
|
output = sum(t.sum() for t in output[0])
|
||||||
|
elif isinstance(output, (list, tuple)):
|
||||||
|
output = sum(t.sum() for t in output[:-1])
|
||||||
output = output.sum()
|
output = output.sum()
|
||||||
output.backward()
|
output.backward()
|
||||||
multihead_outputs = transformer.get_multihead_outputs()
|
multihead_outputs = transformer.get_multihead_outputs()
|
||||||
|
|||||||
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def check_openai_model_output(self, result):
|
def check_openai_model_output(self, result):
|
||||||
|
self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["hidden_states"].size()),
|
list(result["hidden_states"][0].size()),
|
||||||
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
|
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
|
||||||
|
|
||||||
|
|
||||||
@@ -195,7 +196,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
output = model(input_ids, head_mask=head_mask)
|
output = model(input_ids, head_mask=head_mask)
|
||||||
|
|
||||||
output = sum(t.sum() for t in output[:-1])
|
if isinstance(model, OpenAIGPTModel):
|
||||||
|
output = sum(t.sum() for t in output[0])
|
||||||
|
elif isinstance(output, (list, tuple)):
|
||||||
|
output = sum(t.sum() for t in output)
|
||||||
output = output.sum()
|
output = output.sum()
|
||||||
output.backward()
|
output.backward()
|
||||||
multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
|
multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
|
||||||
@@ -229,7 +233,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
|
|||||||
else:
|
else:
|
||||||
output = model(input_ids)
|
output = model(input_ids)
|
||||||
|
|
||||||
output = sum(t.sum() for t in output[:-1])
|
if isinstance(model, OpenAIGPTModel):
|
||||||
|
output = sum(t.sum() for t in output[0])
|
||||||
|
elif isinstance(output, (list, tuple)):
|
||||||
|
output = sum(t.sum() for t in output)
|
||||||
output = output.sum()
|
output = output.sum()
|
||||||
output.backward()
|
output.backward()
|
||||||
multihead_outputs = transformer.get_multihead_outputs()
|
multihead_outputs = transformer.get_multihead_outputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user