[Generation, EncoderDecoder] Apply Encoder Decoder 1.5GB memory… (#3778)
This commit is contained in:
Committed by: GitHub
Parent: 352d5472b0
Commit: 092cf881a5
@@ -704,6 +704,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
effective_batch_size = batch_size
|
effective_batch_size = batch_size
|
||||||
effective_batch_mult = 1
|
effective_batch_mult = 1
|
||||||
|
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
if decoder_start_token_id is None:
|
||||||
|
decoder_start_token_id = bos_token_id
|
||||||
|
|
||||||
|
assert (
|
||||||
|
decoder_start_token_id is not None
|
||||||
|
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||||
|
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||||
|
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||||
|
|
||||||
|
# get encoder and store encoder outputs
|
||||||
|
encoder = self.get_encoder()
|
||||||
|
|
||||||
|
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
||||||
|
|
||||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||||
if num_return_sequences > 1 or num_beams > 1:
|
if num_return_sequences > 1 or num_beams > 1:
|
||||||
input_ids_len = shape_list(input_ids)[-1]
|
input_ids_len = shape_list(input_ids)[-1]
|
||||||
@@ -721,24 +736,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
if decoder_start_token_id is None:
|
|
||||||
decoder_start_token_id = bos_token_id
|
|
||||||
|
|
||||||
assert (
|
|
||||||
decoder_start_token_id is not None
|
|
||||||
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
|
||||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
|
||||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
|
||||||
|
|
||||||
# get encoder and store encoder outputs
|
|
||||||
encoder = self.get_encoder()
|
|
||||||
|
|
||||||
encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
# create empty decoder_input_ids
|
# create empty decoder_input_ids
|
||||||
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
|
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
|
||||||
cur_len = 1
|
cur_len = 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
batch_size == encoder_outputs[0].shape[0]
|
||||||
|
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
||||||
|
|
||||||
|
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||||
|
expanded_batch_idxs = tf.reshape(
|
||||||
|
tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
|
||||||
|
shape=(-1,),
|
||||||
|
)
|
||||||
|
# expand encoder_outputs
|
||||||
|
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
encoder_outputs = None
|
encoder_outputs = None
|
||||||
cur_len = shape_list(input_ids)[-1]
|
cur_len = shape_list(input_ids)[-1]
|
||||||
|
|||||||
Reference in New Issue
Block a user