Update serving signatures and make sure we actually use them (#19034)
* Override save() to use the serving signature as the default * Replace int32 with int64 in all our serving signatures * Remember one very important line so as not to break every test at once * Dtype fix for TFLED * dtype fix for shift_tokens_right in general * Dtype fixes in mBART and RAG * Fix dtypes for test_unpack_inputs * More dtype fixes * Yet more mBART + RAG dtype fixes * Yet more mBART + RAG dtype fixes * Add a check that the model actually has a serving method
This commit is contained in:
@@ -1685,16 +1685,21 @@ _TOKENIZER_FOR_DOC = "{{cookiecutter.camelcase_modelname}}Tokenizer"
|
||||
LARGE_NEGATIVE = -1e8
|
||||
|
||||
|
||||
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
|
||||
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
|
||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
|
||||
pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
|
||||
decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
|
||||
start_tokens = tf.fill((shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype))
|
||||
shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = tf.where(
|
||||
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
|
||||
shifted_input_ids == -100,
|
||||
tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
|
||||
shifted_input_ids,
|
||||
)
|
||||
|
||||
# "Verify that `labels` has only positive values and -100"
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0))
|
||||
assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=shifted_input_ids.dtype))
|
||||
|
||||
# Make sure the assertion op is called by wrapping the result in an identity no-op
|
||||
with tf.control_dependencies([assert_gte0]):
|
||||
|
||||
Reference in New Issue
Block a user