[TFxxxxForSequenceClassifciation] Fix the eager mode after #25085 (#25751)

* TODOS

* Switch .shape -> shape_list

---------

Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
Arthur
2023-10-24 14:33:05 +02:00
committed by GitHub
parent e2d6d5ce57
commit 7bde5d634f

View File

@@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1) tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
- 1 - 1
) )
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1) sequence_lengths = tf.where(
sequence_lengths >= 0,
sequence_lengths,
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
)
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1) in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else: else:
sequence_lengths = -1 sequence_lengths = -1