Fix _shift_right function in TFT5PreTrainedModel (#6214)
This commit is contained in:
committed by
GitHub
parent
0b41867357
commit
06f1692b02
@@ -783,8 +783,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
|||||||
decoder_start_token_id is not None
|
decoder_start_token_id is not None
|
||||||
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
|
||||||
|
|
||||||
# shift inputs to the right
|
shifted_input_ids = tf.cast(input_ids, tf.int32)
|
||||||
shifted_input_ids = tf.zeros_like(input_ids, dtype=tf.int32)
|
|
||||||
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
|
||||||
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
|
||||||
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
|
||||||
@@ -795,9 +794,12 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
|
|||||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
assert tf.math.reduce_any(
|
# "Verify that `labels` has only positive values and -100"
|
||||||
shifted_input_ids >= 0
|
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
|
||||||
).numpy(), "Verify that `labels` has only positive values and -100"
|
|
||||||
|
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||||
|
with tf.control_dependencies([assert_gte0]):
|
||||||
|
shifted_input_ids = tf.identity(shifted_input_ids)
|
||||||
|
|
||||||
return shifted_input_ids
|
return shifted_input_ids
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user