From 7bde5d634f535206e0bf4e279b32d65d1f7568f7 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:33:05 +0200 Subject: [PATCH] [`TFxxxxForSequenceClassifciation`] Fix the eager mode after #25085 (#25751) * TODOS * Switch .shape -> shape_list --------- Co-authored-by: Matt --- src/transformers/models/gptj/modeling_tf_gptj.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index f215adaaac..f5080f674c 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) - 1 ) - sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) + sequence_lengths = tf.where( + sequence_lengths >= 0, + sequence_lengths, + tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1, + ) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) else: sequence_lengths = -1