* TODOS * Switch .shape -> shape_list --------- Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
|
||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||
- 1
|
||||
)
|
||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
||||
sequence_lengths = tf.where(
|
||||
sequence_lengths >= 0,
|
||||
sequence_lengths,
|
||||
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
|
||||
)
|
||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||
else:
|
||||
sequence_lengths = -1
|
||||
|
||||
Reference in New Issue
Block a user