diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index f215adaaac..f5080f674c 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) - 1 ) - sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) + sequence_lengths = tf.where( + sequence_lengths >= 0, + sequence_lengths, + tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1, + ) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) else: sequence_lengths = -1