Generate: TF .generate() can now be exported with dynamic length (#21474)
This commit is contained in:
@@ -849,7 +849,7 @@ class TFGenerationMixin:
|
|||||||
input_ids = inputs_tensor
|
input_ids = inputs_tensor
|
||||||
|
|
||||||
# 7. Prepare `max_length` depending on other stopping criteria.
|
# 7. Prepare `max_length` depending on other stopping criteria.
|
||||||
input_ids_seq_length = input_ids.shape[-1]
|
input_ids_seq_length = shape_list(input_ids)[-1]
|
||||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
if has_default_max_length and generation_config.max_new_tokens is None:
|
if has_default_max_length and generation_config.max_new_tokens is None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
@@ -869,18 +869,23 @@ class TFGenerationMixin:
|
|||||||
UserWarning,
|
UserWarning,
|
||||||
)
|
)
|
||||||
|
|
||||||
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
|
# If the input length is a tensor (i.e. dynamic length), skip length checks
|
||||||
raise ValueError(
|
if not isinstance(input_ids_seq_length, tf.Tensor):
|
||||||
f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
|
if (
|
||||||
f" the maximum length ({generation_config.max_length})"
|
generation_config.min_length is not None
|
||||||
)
|
and generation_config.min_length > generation_config.max_length
|
||||||
if input_ids_seq_length >= generation_config.max_length:
|
):
|
||||||
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
raise ValueError(
|
||||||
logger.warning(
|
f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger"
|
||||||
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
f" than the maximum length ({generation_config.max_length})"
|
||||||
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
)
|
||||||
" increasing`max_new_tokens`."
|
if input_ids_seq_length >= generation_config.max_length:
|
||||||
)
|
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
||||||
|
logger.warning(
|
||||||
|
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
||||||
|
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
||||||
|
" increasing`max_new_tokens`."
|
||||||
|
)
|
||||||
|
|
||||||
# 8. determine generation mode
|
# 8. determine generation mode
|
||||||
is_contrastive_search_gen_mode = (
|
is_contrastive_search_gen_mode = (
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
key = self.split_heads(key)
|
key = self.split_heads(key)
|
||||||
value = self.split_heads(value)
|
value = self.split_heads(value)
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
past_key, past_value = tf.unstack(layer_past, axis=0)
|
past_key, past_value = tf.unstack(layer_past, axis=0, num=2)
|
||||||
key = tf.concat([past_key, key], axis=-2)
|
key = tf.concat([past_key, key], axis=-2)
|
||||||
value = tf.concat([past_value, value], axis=-2)
|
value = tf.concat([past_value, value], axis=-2)
|
||||||
|
|
||||||
|
|||||||
@@ -144,9 +144,10 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
}
|
}
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_generate_tf_function_export(self):
|
def test_generate_tf_function_export_fixed_input_length(self):
|
||||||
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
max_length = 2
|
input_length = 2
|
||||||
|
max_new_tokens = 2
|
||||||
|
|
||||||
class DummyModel(tf.Module):
|
class DummyModel(tf.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
@@ -155,8 +156,8 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
|
|
||||||
@tf.function(
|
@tf.function(
|
||||||
input_signature=(
|
input_signature=(
|
||||||
tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
|
tf.TensorSpec((None, input_length), tf.int32, name="input_ids"),
|
||||||
tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
|
tf.TensorSpec((None, input_length), tf.int32, name="attention_mask"),
|
||||||
),
|
),
|
||||||
jit_compile=True,
|
jit_compile=True,
|
||||||
)
|
)
|
||||||
@@ -164,7 +165,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
outputs = self.model.generate(
|
outputs = self.model.generate(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_new_tokens=max_length,
|
max_new_tokens=max_new_tokens,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
return {"sequences": outputs["sequences"]}
|
return {"sequences": outputs["sequences"]}
|
||||||
@@ -181,5 +182,47 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
|||||||
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
|
"attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
|
||||||
}
|
}
|
||||||
tf_func_outputs = serving_func(**inputs)["sequences"]
|
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||||
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length)
|
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_generate_tf_function_export_fixed_batch_size(self):
|
||||||
|
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
batch_size = 1
|
||||||
|
max_new_tokens = 2
|
||||||
|
|
||||||
|
class DummyModel(tf.Module):
|
||||||
|
def __init__(self, model):
|
||||||
|
super(DummyModel, self).__init__()
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
@tf.function(
|
||||||
|
input_signature=(
|
||||||
|
tf.TensorSpec((batch_size, None), tf.int32, name="input_ids"),
|
||||||
|
tf.TensorSpec((batch_size, None), tf.int32, name="attention_mask"),
|
||||||
|
),
|
||||||
|
jit_compile=True,
|
||||||
|
)
|
||||||
|
def serving(self, input_ids, attention_mask):
|
||||||
|
outputs = self.model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
)
|
||||||
|
return {"sequences": outputs["sequences"]}
|
||||||
|
|
||||||
|
dummy_input_ids = [[2], [102, 103]]
|
||||||
|
dummy_attention_masks = [[1], [1, 1]]
|
||||||
|
dummy_model = DummyModel(model=test_model)
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
|
||||||
|
serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]
|
||||||
|
for input_row in range(len(dummy_input_ids)):
|
||||||
|
inputs = {
|
||||||
|
"input_ids": tf.constant([dummy_input_ids[input_row]]),
|
||||||
|
"attention_mask": tf.constant([dummy_attention_masks[input_row]]),
|
||||||
|
}
|
||||||
|
tf_func_outputs = serving_func(**inputs)["sequences"]
|
||||||
|
tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user