From 38f7461df3fe51308a62a81e4a0e7770a38d7125 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Apr 2020 16:14:52 +0200 Subject: [PATCH] [TFT5, Cache] Add cache to TFT5 (#3772) * correct gpt2 test inputs * make style * delete modeling_gpt2 change in test file * translate from pytorch * correct tests * fix conflicts * fix conflicts * fix conflicts * fix conflicts * make tensorflow t5 caching work * make style * clean reorder cache * remove unnecessary spaces * fix test --- src/transformers/modeling_t5.py | 20 +- src/transformers/modeling_tf_t5.py | 338 +++++++++++++++++++++----- src/transformers/modeling_tf_utils.py | 12 +- tests/test_modeling_t5.py | 3 +- tests/test_modeling_tf_gpt2.py | 2 +- tests/test_modeling_tf_t5.py | 95 +++++++- 6 files changed, 384 insertions(+), 86 deletions(-) diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index e78db03905..d86dbc4bfb 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -351,12 +351,11 @@ class T5Attention(nn.Module): else: k, v = past_key_value_state - if self.is_decoder and use_cache: + if self.is_decoder and use_cache is True: present_key_value_state = ((k, v),) else: present_key_value_state = (None,) - # q = q / math.sqrt(dim_per_head) # No scaling in T5 scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen) if position_bias is None: @@ -486,11 +485,15 @@ class T5Block(nn.Module): if past_key_value_state is not None: assert self.is_decoder, "Only decoder can use `past_key_value_states`" - assert ( - len(past_key_value_state) == 4 - ), "The should be 4 past states. 2 (past / key) for self attention. 2 (past / key) for cross attention. Got {} past key / value states".format( - len(past_key_value_state) + expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4 + + error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format( + expected_num_past_key_value_states, + "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "", + len(past_key_value_state), ) + assert len(past_key_value_state) == expected_num_past_key_value_states, error_message + self_attn_past_key_value_state = past_key_value_state[:2] cross_attn_past_key_value_state = past_key_value_state[2:] else: @@ -507,7 +510,7 @@ class T5Block(nn.Module): hidden_states, present_key_value_state = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights - if self.is_decoder: + if self.is_decoder and encoder_hidden_states is not None: # the actual query length is unknown for cross attention # if using past key value states. Need to inject it here if present_key_value_state is not None: @@ -691,7 +694,6 @@ class T5Stack(T5PreTrainedModel): if past_key_value_states is None: past_key_value_states = [None] * len(self.block) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device) @@ -732,7 +734,7 @@ class T5Stack(T5PreTrainedModel): # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) position_bias = layer_outputs[3 if self.output_attentions else 2] - if self.is_decoder: + if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] # append next layer key value states present_key_value_states = present_key_value_states + (present_key_value_state,) diff --git a/src/transformers/modeling_tf_t5.py b/src/transformers/modeling_tf_t5.py index eede1fd675..e901e8d70d 100644 --- a/src/transformers/modeling_tf_t5.py +++ b/src/transformers/modeling_tf_t5.py @@ -185,16 +185,39 @@ class TFT5Attention(tf.keras.layers.Layer): return values def call( - self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False, + self, + input, + mask=None, + kv=None, + position_bias=None, + cache=None, + past_key_value_state=None, + head_mask=None, + query_length=None, + use_cache=False, + training=False, ): """ Self-attention (if kv is None) or attention over source sentence (provided by kv). """ # Input is (bs, qlen, dim) # Mask is (bs, klen) (non-causal) or (bs, klen, klen) + # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head) bs, qlen, dim = shape_list(input) + + if past_key_value_state is not None: + assert self.is_decoder is True, "Encoder cannot cache past key value states" + assert ( + len(past_key_value_state) == 2 + ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format( + len(past_key_value_state) + ) + real_qlen = qlen + shape_list(past_key_value_state[0])[2] if query_length is None else query_length + else: + real_qlen = qlen + if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen + klen = real_qlen else: klen = shape_list(kv)[1] @@ -207,36 +230,51 @@ class TFT5Attention(tf.keras.layers.Layer): return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim)) q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head) + if kv is None: k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head) v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: + elif past_key_value_state is None: k = v = kv k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head) v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head) - if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head) - v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) - else: - k, v = cache[self.layer_id] - cache[self.layer_id] = (k, v) + if past_key_value_state is not None: + if kv is None: + k_, v_ = past_key_value_state + k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head) + v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head) + else: + k, v = past_key_value_state + + # to cope with keras serialization + # we need to cast `use_cache` to correct bool + # if it is a tensor + if tf.is_tensor(use_cache): + if hasattr(use_cache, "numpy"): + use_cache = bool(use_cache.numpy()) + else: + use_cache = True + + if self.is_decoder and use_cache is True: + present_key_value_state = ((k, v),) + else: + present_key_value_state = (None,) - # q = q / math.sqrt(dim_per_head) # No scaling in T5 - # scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen) scores = tf.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen) if position_bias is None: if not self.has_relative_attention_bias: raise ValueError("No position_bias provided and no weights to compute position_bias") - position_bias = self.compute_bias(qlen, klen) + position_bias = self.compute_bias(real_qlen, klen) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value_state is not None: + position_bias = position_bias[:, :, -1:, :] + if mask is not None: - position_bias = position_bias + mask - # mask = (mask == 0).expand_as(scores) # (bs, n_heads, qlen, klen) - # scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen) + position_bias = position_bias + mask # (bs, n_heads, qlen, klen) scores += position_bias weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen) @@ -251,7 +289,8 @@ class TFT5Attention(tf.keras.layers.Layer): context = self.o(context) - outputs = (context,) + outputs = (context,) + present_key_value_state + if self.output_attentions: outputs = outputs + (weights,) if self.has_relative_attention_bias: @@ -269,11 +308,24 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer): self.dropout = tf.keras.layers.Dropout(config.dropout_rate) def call( - self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False, + self, + hidden_states, + attention_mask=None, + position_bias=None, + head_mask=None, + past_key_value_state=None, + use_cache=False, + training=False, ): norm_x = self.layer_norm(hidden_states) attention_output = self.SelfAttention( - norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, training=training, + norm_x, + mask=attention_mask, + position_bias=position_bias, + head_mask=head_mask, + past_key_value_state=past_key_value_state, + use_cache=use_cache, + training=training, ) y = attention_output[0] layer_output = hidden_states + self.dropout(y, training=training) @@ -291,11 +343,28 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer): self.dropout = tf.keras.layers.Dropout(config.dropout_rate) def call( - self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False, + self, + hidden_states, + kv, + attention_mask=None, + position_bias=None, + head_mask=None, + past_key_value_state=None, + query_length=None, + use_cache=False, + training=False, ): norm_x = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( - norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask, training=training, + norm_x, + mask=attention_mask, + kv=kv, + position_bias=position_bias, + head_mask=head_mask, + past_key_value_state=past_key_value_state, + query_length=query_length, + use_cache=use_cache, + training=training, ) y = attention_output[0] layer_output = hidden_states + self.dropout(y, training=training) @@ -317,9 +386,8 @@ class TFT5Block(tf.keras.layers.Layer): config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1", ) ) - self.layer.append(TFT5LayerFF(config, name="layer_._2")) - else: - self.layer.append(TFT5LayerFF(config, name="layer_._1")) + + self.layer.append(TFT5LayerFF(config, name="layer_._{}".format(len(self.layer)))) def call( self, @@ -330,35 +398,73 @@ class TFT5Block(tf.keras.layers.Layer): encoder_attention_mask=None, encoder_decoder_position_bias=None, head_mask=None, + past_key_value_state=None, + use_cache=False, training=False, ): + + if past_key_value_state is not None: + assert self.is_decoder, "Only decoder can use `past_key_value_states`" + expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4 + + error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format( + expected_num_past_key_value_states, + "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "", + len(past_key_value_state), + ) + assert len(past_key_value_state) == expected_num_past_key_value_states, error_message + + self_attn_past_key_value_state = past_key_value_state[:2] + cross_attn_past_key_value_state = past_key_value_state[2:] + else: + self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None + self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, head_mask=head_mask, + past_key_value_state=self_attn_past_key_value_state, + use_cache=use_cache, training=training, ) - hidden_states = self_attention_outputs[0] - outputs = self_attention_outputs[1:] + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + if self.is_decoder and encoder_hidden_states is not None: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = shape_list(present_key_value_state[0])[2] + else: + query_length = None - if not self.is_decoder: - hidden_states = self.layer[1](hidden_states, training=training) - else: cross_attention_outputs = self.layer[1]( hidden_states, kv=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, head_mask=head_mask, + past_key_value_state=cross_attn_past_key_value_state, + query_length=query_length, + use_cache=use_cache, training=training, ) hidden_states = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] - hidden_states = self.layer[2](hidden_states, training=training) + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] - outputs = (hidden_states,) + outputs # add attentions if we output them - return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states, training=training) + outputs = (hidden_states,) + + # Add attentions if we output them + outputs = outputs + (present_key_value_state,) + attention_outputs + return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) class _NoLayerEmbedTokens(object): @@ -437,6 +543,8 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask=None, inputs_embeds=None, head_mask=None, + past_key_value_states=None, + use_cache=False, training=False, ): @@ -456,12 +564,26 @@ class TFT5MainLayer(tf.keras.layers.Layer): batch_size, seq_length = input_shape + if past_key_value_states is not None: + assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format( + input_shape, (batch_size, 1) + ) + # required mask seq length can be calculated via length of past + # key value states and seq_length = 1 for the last token + mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length + else: + mask_seq_length = seq_length + if attention_mask is None: - attention_mask = tf.fill((batch_size, seq_length), 1) - if self.is_decoder and encoder_attention_mask is None: - encoder_seq_length = encoder_hidden_states.shape[1] + attention_mask = tf.fill((batch_size, mask_seq_length), 1) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = shape_list(encoder_hidden_states)[1] encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1) + # initialize past_key_value_states with `None` if past does not exist + if past_key_value_states is None: + past_key_value_states = [None] * len(self.block) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. attention_mask = tf.cast(attention_mask, dtype=tf.float32) @@ -469,16 +591,18 @@ class TFT5MainLayer(tf.keras.layers.Layer): if num_dims_attention_mask == 3: extended_attention_mask = attention_mask[:, None, :, :] elif num_dims_attention_mask == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] + # Provided a padding mask of dimensions [batch_size, mask_seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if self.config.is_decoder: - seq_ids = tf.range(seq_length) + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) causal_mask = tf.less_equal( - tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)), seq_ids[None, :, None], + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None], ) causal_mask = tf.cast(causal_mask, dtype=tf.float32) extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + if past_key_value_states[0] is not None: + extended_attention_mask = extended_attention_mask[:, :, -1:, :] else: extended_attention_mask = attention_mask[:, None, None, :] @@ -495,8 +619,9 @@ class TFT5MainLayer(tf.keras.layers.Layer): extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 - if self.is_decoder: + if self.is_decoder and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length] # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32) num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) @@ -525,13 +650,15 @@ class TFT5MainLayer(tf.keras.layers.Layer): head_mask = [None] * self.num_hidden_layers # head_mask = tf.constant([0] * self.num_hidden_layers) + present_key_value_states = () all_hidden_states = () all_attentions = () position_bias = None encoder_decoder_position_bias = None - hidden_states = inputs_embeds - for i, layer_module in enumerate(self.block): + hidden_states = self.dropout(inputs_embeds, training=training) + + for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -543,18 +670,24 @@ class TFT5MainLayer(tf.keras.layers.Layer): encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, head_mask=head_mask[i], + past_key_value_state=past_key_value_state, + use_cache=use_cache, training=training, ) - hidden_states = layer_outputs[0] + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] if i == 0: # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[2 if self.output_attentions else 1] - if self.is_decoder: - encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2] + position_bias = layer_outputs[3 if self.output_attentions else 2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3] + # append next layer key value states + present_key_value_states = present_key_value_states + (present_key_value_state,) if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_attentions = all_attentions + (layer_outputs[2],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states, training=training) @@ -564,6 +697,9 @@ class TFT5MainLayer(tf.keras.layers.Layer): all_hidden_states = all_hidden_states + (hidden_states,) outputs = (hidden_states,) + if use_cache is True: + assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) + outputs = outputs + (present_key_value_states,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) if self.output_attentions: @@ -650,6 +786,7 @@ T5_INPUTS_DOCSTRING = r""" :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`): Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation. + If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`). attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: @@ -660,6 +797,13 @@ T5_INPUTS_DOCSTRING = r""" Used in the cross-attention of the decoder. decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains pre-computed key and value hidden-states of the attention blocks. + Can be used to speed up decoding. + If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`). inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors @@ -705,6 +849,12 @@ class TFT5Model(TFT5PreTrainedModel): def get_output_embeddings(self): return self.shared + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" @@ -712,6 +862,11 @@ class TFT5Model(TFT5PreTrainedModel): :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs. last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. + If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``): + Contains pre-computed key and value hidden-states of the attention blocks. + Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input). + Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`. hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -743,12 +898,14 @@ class TFT5Model(TFT5PreTrainedModel): # retrieve arguments input_ids = kwargs.get("inputs", None) - decoder_input_ids = kwargs.get("decoder_input_ids", None) + inputs_embeds = kwargs.get("inputs_embeds", None) attention_mask = kwargs.get("attention_mask", None) encoder_outputs = kwargs.get("encoder_outputs", None) + decoder_input_ids = kwargs.get("decoder_input_ids", None) decoder_attention_mask = kwargs.get("decoder_attention_mask", None) - inputs_embeds = kwargs.get("inputs_embeds", None) decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None) + decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None) + use_cache = kwargs.get("use_cache", True) head_mask = kwargs.get("head_mask", None) # Encode if needed (training, first prediction pass) @@ -759,16 +916,30 @@ class TFT5Model(TFT5PreTrainedModel): hidden_states = encoder_outputs[0] + # If decoding with past key value states, only the last tokens + # should be given as an input + if decoder_past_key_value_states is not None: + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + # Decode decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, + past_key_value_states=decoder_past_key_value_states, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, + use_cache=use_cache, ) + if use_cache is True: + past = ((encoder_outputs, decoder_outputs[1]),) + decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:] + return decoder_outputs + encoder_outputs @@ -802,6 +973,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): def get_encoder(self): return self.encoder + def get_decoder(self): + return self.decoder + @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING) def call(self, inputs, **kwargs): r""" @@ -811,6 +985,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): Classification loss (cross entropy). prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``): + Contains pre-computed key and value hidden-states of the attention blocks. + Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input). + Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`. hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -850,6 +1028,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): attention_mask = kwargs.get("attention_mask", None) encoder_outputs = kwargs.get("encoder_outputs", None) decoder_attention_mask = kwargs.get("decoder_attention_mask", None) + decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None) + use_cache = kwargs.get("use_cache", True) inputs_embeds = kwargs.get("inputs_embeds", None) decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None) head_mask = kwargs.get("head_mask", None) @@ -863,16 +1043,32 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): hidden_states = encoder_outputs[0] + # If decoding with past key value states, only the last tokens + # should be given as an input + if decoder_past_key_value_states is not None: + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + # Decode decoder_outputs = self.decoder( decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, + past_key_value_states=decoder_past_key_value_states, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, + use_cache=use_cache, ) + # insert decoder past at right place + # to speed up decoding + if use_cache is True: + past = ((encoder_outputs, decoder_outputs[1]),) + decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:] + sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5) embed_tokens = self.get_output_embeddings() lm_logits = embed_tokens(sequence_output, mode="linear") @@ -880,22 +1076,46 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel): return decoder_outputs + encoder_outputs - def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs): assert past is not None, "past has to be defined for encoder_outputs" # first step - if type(past) is tuple: - encoder_outputs = past + if len(past) < 2: + encoder_outputs, decoder_past_key_value_states = past, None else: - encoder_outputs = (past,) + encoder_outputs, decoder_past_key_value_states = past[0], past[1] return { "inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy "decoder_input_ids": input_ids, # input_ids are the decoder_input_ids + "decoder_past_key_value_states": decoder_past_key_value_states, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, + "use_cache": use_cache, } def _reorder_cache(self, past, beam_idx): - # past does not have to be re-ordered for T5. - return past + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + + if len(past) < 2: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + decoder_past = past[1] + past = (past[0],) + reordered_decoder_past = () + + for layer_past_states in decoder_past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),) + + assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0]) + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return past + (reordered_decoder_past,) diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 6b15f13db5..8f5f7dedde 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -1299,17 +1299,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): @staticmethod def _reorder_cache(past, beam_idx): - reordered_past = [] - for layer_past in past: - # get the correct batch idx from layer past batch dim - # batch dim of `past` and `mems` is at 2nd position - reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx] - reordered_layer_past = tf.concat(reordered_layer_past, axis=1) - # check that shape matches - assert shape_list(reordered_layer_past) == shape_list(layer_past) - reordered_past.append(reordered_layer_past) - past = tuple(reordered_past) - return past + return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past) def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty): diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 27fe1cadfb..1daf1bf32b 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6)) def create_and_check_t5_decoder_model_attention_mask_past( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, @@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): def create_t5_and_check_t5_generate_with_past_key_value_states( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, ): - config.num_layers = 1 model = T5ForConditionalGeneration(config=config) model.to(torch_device) model.eval() diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index 767fa3a2d0..f8b2ca8e3b 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -191,7 +191,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase): output_from_past_slice = output_from_past[:, 0, random_slice_idx] # test that outputs are equal for slice - tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12) + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) def create_and_check_gpt2_model_attention_mask_past( self, config, input_ids, input_mask, head_mask, token_type_ids, *args diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index 37a243b45c..de2ee07930 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_tf, slow if is_tf_available(): + import tensorflow as tf from transformers import TFT5Model, TFT5ForConditionalGeneration, T5Tokenizer @@ -111,14 +112,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): "decoder_input_ids": input_ids, "decoder_attention_mask": input_mask, } - encoder_output, decoder_output = model(inputs) + decoder_output, decoder_past, encoder_output = model(inputs) - encoder_output, decoder_output = model( + decoder_output, decoder_past, encoder_output = model( input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids ) - result = { "encoder_output": encoder_output.numpy(), + "decoder_past": decoder_past, "decoder_output": decoder_output.numpy(), } self.parent.assertListEqual( @@ -127,6 +128,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): self.parent.assertListEqual( list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size] ) + self.parent.assertEqual(len(decoder_past), 2) + # decoder_past[0] should correspond to encoder output + self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output))) + # There should be `num_layers` key value embeddings stored in decoder_past[1] + self.parent.assertEqual(len(decoder_past[1]), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple + self.parent.assertEqual(len(decoder_past[1][0]), 4) def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels): model = TFT5ForConditionalGeneration(config=config) @@ -136,7 +144,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): "decoder_attention_mask": input_mask, } - prediction_scores, decoder_output = model(inputs_dict) + prediction_scores, _, _ = model(inputs_dict) result = { "prediction_scores": prediction_scores.numpy(), @@ -145,6 +153,76 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size] ) + def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask): + model = TFT5Model(config=config).get_decoder() + + input_ids = input_ids[:1, :] + self.batch_size = 1 + + # first forward pass + _, past_key_value_states = model(input_ids, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + + output_from_no_past = model(next_input_ids)[0] + output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0] + + # select random slice + random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1])) + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + + def create_and_check_t5_decoder_model_attention_mask_past( + self, config, input_ids, decoder_input_ids, attention_mask + ): + model = TFT5Model(config=config).get_decoder() + + # create attention mask + half_seq_length = self.seq_length // 2 + attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32) + attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32) + attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) + + # first forward pass + _, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size) + vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change) + condition = tf.transpose( + tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)) + ) + input_ids = tf.where(condition, random_other_next_tokens, input_ids) + + # append to next input_ids and attn_mask + next_input_ids = tf.concat([input_ids, next_tokens], axis=-1) + attn_mask = tf.concat([attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], axis=1,) + + # get two different outputs + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0] + output_from_past = model( + next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask + )[0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx] + output_from_past_slice = output_from_past[:, 0, random_slice_idx] + + # test that outputs are equal for slice + tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, token_labels) = config_and_inputs @@ -152,6 +230,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): "inputs": input_ids, "decoder_input_ids": input_ids, "decoder_attention_mask": input_mask, + "use_cache": tf.convert_to_tensor([False]), } return config, inputs_dict @@ -170,6 +249,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs) + def test_t5_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs) + + def test_t5_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in ["t5-small"]: