From 092cf881a57de264e9a2798e067c92f9bd5edfc1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 14 Apr 2020 04:29:28 +0200 Subject: [PATCH] =?UTF-8?q?[Generation,=20EncoderDecoder]=20Apply=20Encode?= =?UTF-8?q?r=20Decoder=201.5GB=20memory=E2=80=A6=20(#3778)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/modeling_tf_utils.py | 40 ++++++++++++++++++--------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index fc75984c8b..e43d745e3b 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -704,6 +704,21 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): effective_batch_size = batch_size effective_batch_mult = 1 + if self.config.is_encoder_decoder: + if decoder_start_token_id is None: + decoder_start_token_id = bos_token_id + + assert ( + decoder_start_token_id is not None + ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" + assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) + assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) + + # get encoder and store encoder outputs + encoder = self.get_encoder() + + encoder_outputs = encoder(input_ids, attention_mask=attention_mask) + # Expand input ids if num_beams > 1 or num_return_sequences > 1 if num_return_sequences > 1 or num_beams > 1: input_ids_len = shape_list(input_ids)[-1] @@ -721,24 +736,23 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ) # shape: (batch_size * num_return_sequences * num_beams, cur_len) if self.config.is_encoder_decoder: - if decoder_start_token_id is None: - decoder_start_token_id = bos_token_id - - assert ( - decoder_start_token_id is not None - ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation" - assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self) - assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder) - - # get encoder and store encoder outputs - encoder = self.get_encoder() - - encoder_outputs = encoder(input_ids, attention_mask=attention_mask) # create empty decoder_input_ids input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id cur_len = 1 + assert ( + batch_size == encoder_outputs[0].shape[0] + ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} " + + # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1) + expanded_batch_idxs = tf.reshape( + tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1), + shape=(-1,), + ) + # expand encoder_outputs + encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:]) + else: encoder_outputs = None cur_len = shape_list(input_ids)[-1]