fix tf led pt test (#9513)
This commit is contained in:
committed by
GitHub
parent
1e3c362235
commit
6c8ec2a931
@@ -166,7 +166,13 @@ def prepare_led_inputs_dict(
|
|||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = tf.cast(tf.math.not_equal(decoder_input_ids, config.pad_token_id), tf.int8)
|
decoder_attention_mask = tf.concat(
|
||||||
|
[
|
||||||
|
tf.ones(decoder_input_ids[:, :1].shape, dtype=tf.int8),
|
||||||
|
tf.cast(tf.math.not_equal(decoder_input_ids[:, 1:], config.pad_token_id), tf.int8),
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
|||||||
Reference in New Issue
Block a user