change shape to support dynamic batch input in tf.function XLA generate for tf serving (#18372)
* change shape to support dynamic batch input in tf.generate
* add tests

Co-authored-by: nlpcatcode <nlpcodecat@gmail.com>
This commit is contained in:
@@ -1533,7 +1533,7 @@ class TFGenerationMixin:
|
|||||||
# 2. Define model inputs
|
# 2. Define model inputs
|
||||||
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
|
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
|
||||||
# inputs_ids now has to be defined and cannot be None anymore
|
# inputs_ids now has to be defined and cannot be None anymore
|
||||||
batch_size = input_ids.shape[0]
|
batch_size = shape_list(input_ids)[0]
|
||||||
|
|
||||||
# 3. Prepare other model kwargs
|
# 3. Prepare other model kwargs
|
||||||
if output_attentions is not None:
|
if output_attentions is not None:
|
||||||
@@ -1702,7 +1702,8 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
|
def _expand_to_num_beams(tensor: tf.Tensor, num_beams: int) -> tf.Tensor:
|
||||||
return tf.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:])
|
shape = shape_list(tensor)
|
||||||
|
return tf.broadcast_to(tensor[:, None], (shape[0], num_beams) + tuple(shape[1:]))
|
||||||
|
|
||||||
def _prepare_attention_mask_for_generation(
|
def _prepare_attention_mask_for_generation(
|
||||||
self,
|
self,
|
||||||
@@ -2162,7 +2163,7 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# 3. init tensors to use for "xla-compileable" generate function
|
# 3. init tensors to use for "xla-compileable" generate function
|
||||||
batch_size, cur_len = input_ids.shape
|
batch_size, cur_len = shape_list(input_ids)
|
||||||
|
|
||||||
# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
|
# initialize `generated` (`input_ids` padded with `pad_token_id`), `finished_sequences`
|
||||||
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
|
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
|
||||||
@@ -2432,7 +2433,7 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# 3. init tensors to use for "xla-compileable" generate function
|
# 3. init tensors to use for "xla-compileable" generate function
|
||||||
batch_size, cur_len = input_ids.shape
|
batch_size, cur_len = shape_list(input_ids)
|
||||||
|
|
||||||
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
|
# initialize `generated` (pre-populated with `pad_token_id`), `finished_sequences`
|
||||||
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
|
input_ids_padding = tf.ones((batch_size, max_length - cur_len), dtype=tf.int32) * (pad_token_id or 0)
|
||||||
@@ -2678,18 +2679,16 @@ class TFGenerationMixin:
|
|||||||
|
|
||||||
def flatten_beam_dim(tensor, batch_axis=0):
|
def flatten_beam_dim(tensor, batch_axis=0):
|
||||||
"""Flattens the first two dimensions of a non-scalar array."""
|
"""Flattens the first two dimensions of a non-scalar array."""
|
||||||
|
shape = shape_list(tensor)
|
||||||
return tf.reshape(
|
return tf.reshape(
|
||||||
tensor,
|
tensor,
|
||||||
tensor.shape[:batch_axis]
|
shape[:batch_axis] + [shape[batch_axis] * shape[batch_axis + 1]] + shape[batch_axis + 2 :],
|
||||||
+ [tensor.shape[batch_axis] * tensor.shape[batch_axis + 1]]
|
|
||||||
+ tensor.shape[batch_axis + 2 :],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
|
def unflatten_beam_dim(tensor, batch_size, num_beams, batch_axis=0):
|
||||||
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
|
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
|
||||||
return tf.reshape(
|
shape = shape_list(tensor)
|
||||||
tensor, tensor.shape[:batch_axis] + [batch_size, num_beams] + tensor.shape[batch_axis + 1 :]
|
return tf.reshape(tensor, shape[:batch_axis] + [batch_size, num_beams] + shape[batch_axis + 1 :])
|
||||||
)
|
|
||||||
|
|
||||||
def gather_beams(nested, beam_indices, batch_axis=0):
|
def gather_beams(nested, beam_indices, batch_axis=0):
|
||||||
"""Gathers the beam slices indexed by beam_indices into new beam array."""
|
"""Gathers the beam slices indexed by beam_indices into new beam array."""
|
||||||
@@ -2748,7 +2747,7 @@ class TFGenerationMixin:
|
|||||||
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
# 3. init tensors to use for "xla-compileable" generate function
|
# 3. init tensors to use for "xla-compileable" generate function
|
||||||
batch_size, num_beams, cur_len = input_ids.shape
|
batch_size, num_beams, cur_len = shape_list(input_ids)
|
||||||
|
|
||||||
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
|
# per batch, beam-item holding current token in loop, pre-populated with `pad_token_id`
|
||||||
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
|
input_ids_padding = tf.ones((batch_size, num_beams, max_length - cur_len), dtype=tf.int32) * (
|
||||||
@@ -2894,7 +2893,7 @@ class TFGenerationMixin:
|
|||||||
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
|
eos_in_next_token = tf.broadcast_to(eos_in_next_token, topk_sequences[:, :, cur_len].shape)
|
||||||
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
|
did_topk_just_finished = eos_in_next_token & tf.broadcast_to(
|
||||||
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
|
tf.concat((tf.ones((num_beams), dtype=tf.bool), tf.zeros((num_beams), dtype=tf.bool)), axis=0),
|
||||||
eos_in_next_token.shape,
|
shape_list(eos_in_next_token),
|
||||||
)
|
)
|
||||||
|
|
||||||
# non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
|
# non-top `num_beams` eos tokens can't be used to finish a beam, but the others can't be used in the next
|
||||||
@@ -2917,7 +2916,7 @@ class TFGenerationMixin:
|
|||||||
topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
|
topk_log_probs = topk_log_probs / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
|
||||||
beams_in_batch_are_full = (
|
beams_in_batch_are_full = (
|
||||||
tf.broadcast_to(
|
tf.broadcast_to(
|
||||||
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), did_topk_just_finished.shape
|
tf.math.reduce_all(is_sent_finished, axis=-1, keepdims=True), shape_list(did_topk_just_finished)
|
||||||
)
|
)
|
||||||
& early_stopping
|
& early_stopping
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
|
TFAutoModelForSeq2SeqLM,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
@@ -2163,6 +2164,46 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
def test_generate_tf_function_export(self):
    """Check that `generate` exported as an XLA-compiled serving signature matches eager `generate`.

    The serving signature leaves the batch dimension unspecified (`None`), so the
    reloaded SavedModel is called with every batch size from 1 up to the number of
    dummy rows, and each result is compared with a direct `generate` call on the
    same inputs.
    """
    max_length = 8
    test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")

    class DummyModel(tf.Module):
        """`tf.Module` wrapper that exposes `generate` through a fixed input signature."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        @tf.function(
            input_signature=(
                # First dimension is None so any batch size is accepted.
                tf.TensorSpec((None, max_length), tf.int32, name="input_ids"),
                tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"),
            ),
            jit_compile=True,
        )
        def serving(self, input_ids, attention_mask):
            generated = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_length,
                return_dict_in_generate=True,
            )
            return {"sequences": generated["sequences"]}

    dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
    dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
    dummy_model = DummyModel(model=test_model)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Round-trip through a SavedModel so the check exercises the exported signature.
        tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
        serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"]

        num_rows = len(dummy_input_ids)
        for batch_size in range(1, num_rows + 1):
            inputs = {
                "input_ids": tf.constant(dummy_input_ids[:batch_size]),
                "attention_mask": tf.constant(dummy_attention_masks[:batch_size]),
            }
            exported_sequences = serving_func(**inputs)["sequences"]
            eager_sequences = test_model.generate(**inputs, max_new_tokens=max_length)
            tf.debugging.assert_equal(exported_sequences, eager_sequences)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@is_staging_test
|
@is_staging_test
|
||||||
|
|||||||
Reference in New Issue
Block a user