From 2a91a9ef663776ad8259ff22fd285f3cfc888d0f Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 5 Apr 2023 13:16:00 +0100 Subject: [PATCH] Fix PT-TF equivalence test for GPT1 (#22586) * Re-enable skipped test and fix the hidden state shape issue * Actually fix the bug instead of just doing something wrong --- src/transformers/models/openai/modeling_tf_openai.py | 8 +++++++- tests/models/openai/test_modeling_openai.py | 4 ---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py index 4bd4f506e9..5723001729 100644 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ b/src/transformers/models/openai/modeling_tf_openai.py @@ -748,6 +748,12 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ) hidden_states = transformer_outputs[0] hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) + if return_dict and output_hidden_states: + # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the + # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) + all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) + else: + all_hidden_states = None lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear") mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) mc_logits = tf.squeeze(mc_logits, axis=-1) @@ -758,7 +764,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): return TFOpenAIGPTDoubleHeadsModelOutput( logits=lm_logits, mc_logits=mc_logits, - hidden_states=transformer_outputs.hidden_states, + hidden_states=all_hidden_states, attentions=transformer_outputs.attentions, ) diff --git a/tests/models/openai/test_modeling_openai.py b/tests/models/openai/test_modeling_openai.py index 77d64c100a..0e8ba6d9ce 100644 --- a/tests/models/openai/test_modeling_openai.py +++ b/tests/models/openai/test_modeling_openai.py @@ -274,10 +274,6 @@ class OpenAIGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester model = OpenAIGPTModel.from_pretrained(model_name) self.assertIsNotNone(model) - @unittest.skip("Fix me Matt") - def test_pt_tf_model_equivalence(self): - pass - @require_torch class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):