* TODOS * Switch .shape -> shape_list --------- Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -870,7 +870,11 @@ class TFGPTJForSequenceClassification(TFGPTJPreTrainedModel, TFSequenceClassific
|
|||||||
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
tf.argmax(tf.cast(tf.math.equal(input_ids, self.config.pad_token_id), input_ids.dtype), axis=-1)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
sequence_lengths = tf.where(sequence_lengths >= 0, sequence_lengths, input_ids.shape[-1] - 1)
|
sequence_lengths = tf.where(
|
||||||
|
sequence_lengths >= 0,
|
||||||
|
sequence_lengths,
|
||||||
|
tf.cast(shape_list(input_ids[-1]), sequence_lengths.dtype) - 1,
|
||||||
|
)
|
||||||
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
|
||||||
else:
|
else:
|
||||||
sequence_lengths = -1
|
sequence_lengths = -1
|
||||||
|
|||||||
Reference in New Issue
Block a user