Fix PT-TF equivalence test for GPT1 (#22586)
* Re-enable skipped test and fix the hidden state shape issue
* Actually fix the bug instead of just doing something wrong
This commit is contained in:
@@ -748,6 +748,12 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
)
|
)
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||||
|
if return_dict and output_hidden_states:
|
||||||
|
# We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
|
||||||
|
# input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged)
|
||||||
|
all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
|
||||||
|
else:
|
||||||
|
all_hidden_states = None
|
||||||
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
||||||
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||||
@@ -758,7 +764,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
return TFOpenAIGPTDoubleHeadsModelOutput(
|
return TFOpenAIGPTDoubleHeadsModelOutput(
|
||||||
logits=lm_logits,
|
logits=lm_logits,
|
||||||
mc_logits=mc_logits,
|
mc_logits=mc_logits,
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -274,10 +274,6 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model = OpenAIGPTModel.from_pretrained(model_name)
|
model = OpenAIGPTModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
@unittest.skip("Fix me Matt")
|
|
||||||
def test_pt_tf_model_equivalence(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user