diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 5235629c60..9ac085495f 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -16,7 +16,6 @@ import copy -import itertools import logging import math import os @@ -185,13 +184,11 @@ class T5LayerFF(nn.Module): class T5Attention(nn.Module): - NEW_ID = itertools.count() - def __init__(self, config, has_relative_attention_bias=False): super().__init__() - self.layer_id = next(T5Attention.NEW_ID) self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias + self.output_past = config.output_past self.output_attentions = config.output_attentions self.relative_attention_num_buckets = config.relative_attention_num_buckets @@ -294,15 +291,37 @@ class T5Attention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) return values - def forward(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None): + def forward( + self, + input, + mask=None, + kv=None, + position_bias=None, + past_key_value_state=None, + head_mask=None, + query_length=None, + ): """ Self-attention (if kv is None) or attention over source sentence (provided by kv). """ # Input is (bs, qlen, dim) # Mask is (bs, klen) (non-causal) or (bs, klen, klen) + # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head) bs, qlen, dim = input.size() + + if past_key_value_state is not None: + assert self.is_decoder is True, "Encoder cannot cache past key value states" + assert ( + len(past_key_value_state) == 2 + ), "past_key_value_state should have 2 past states: keys and values. 
Got {} past states".format( + len(past_key_value_state) + ) + real_qlen = qlen + past_key_value_state[0].shape[2] if query_length is None else query_length + else: + real_qlen = qlen + if kv is None: - klen = qlen if cache is None else cache["slen"] + qlen + klen = real_qlen else: klen = kv.size(1) @@ -315,23 +334,27 @@ class T5Attention(nn.Module): return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim) q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head) + if kv is None: k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head) v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head) - elif cache is None or self.layer_id not in cache: + elif past_key_value_state is None: k = v = kv k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head) v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head) - if cache is not None: - if self.layer_id in cache: - if kv is None: - k_, v_ = cache[self.layer_id] - k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) - v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) - else: - k, v = cache[self.layer_id] - cache[self.layer_id] = (k, v) + if past_key_value_state is not None: + if kv is None: + k_, v_ = past_key_value_state + k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) + v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) + else: + k, v = past_key_value_state + + if self.is_decoder and self.output_past: + present_key_value_state = ((k, v),) + else: + present_key_value_state = (None,) # q = q / math.sqrt(dim_per_head) # No scaling in T5 scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen) @@ -339,7 +362,13 @@ class T5Attention(nn.Module): if position_bias is None: if not self.has_relative_attention_bias: raise ValueError("No position_bias provided and no weights to compute position_bias") - position_bias = self.compute_bias(qlen, klen) + position_bias = self.compute_bias(real_qlen, klen) + + # if key and values are already 
calculated + # we want only the last query position bias + if past_key_value_state is not None: + position_bias = position_bias[:, :, -1:, :] + if mask is not None: position_bias = position_bias + mask # (bs, n_heads, qlen, klen) @@ -357,6 +386,13 @@ class T5Attention(nn.Module): context = self.o(context) outputs = (context,) + + if self.output_past is False or self.is_decoder is False: + assert ( + present_key_value_state[0] is None + ), "Key/Value projections should not be stored if {} is not decoder or output_past is False".format(self) + + outputs = outputs + present_key_value_state if self.output_attentions: outputs = outputs + (weights,) if self.has_relative_attention_bias: @@ -371,10 +407,16 @@ class T5LayerSelfAttention(nn.Module): self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None): + def forward( + self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, past_key_value_state=None + ): norm_x = self.layer_norm(hidden_states) attention_output = self.SelfAttention( - norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask + norm_x, + mask=attention_mask, + position_bias=position_bias, + head_mask=head_mask, + past_key_value_state=past_key_value_state, ) y = attention_output[0] layer_output = hidden_states + self.dropout(y) @@ -389,10 +431,25 @@ class T5LayerCrossAttention(nn.Module): self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None): + def forward( + self, + hidden_states, + kv, + attention_mask=None, + position_bias=None, + head_mask=None, + past_key_value_state=None, + query_length=None, + ): norm_x = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( - norm_x, 
mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask + norm_x, + mask=attention_mask, + kv=kv, + position_bias=position_bias, + head_mask=head_mask, + past_key_value_state=past_key_value_state, + query_length=query_length, ) y = attention_output[0] layer_output = hidden_states + self.dropout(y) @@ -403,14 +460,14 @@ class T5LayerCrossAttention(nn.Module): class T5Block(nn.Module): def __init__(self, config, has_relative_attention_bias=False): super().__init__() + self.output_past = config.output_past self.is_decoder = config.is_decoder self.layer = nn.ModuleList() self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) if self.is_decoder: self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias)) - self.layer.append(T5LayerFF(config)) - else: - self.layer.append(T5LayerFF(config)) + + self.layer.append(T5LayerFF(config)) def forward( self, @@ -421,31 +478,63 @@ class T5Block(nn.Module): encoder_attention_mask=None, encoder_decoder_position_bias=None, head_mask=None, + past_key_value_state=None, ): - self_attention_outputs = self.layer[0]( - hidden_states, attention_mask=attention_mask, position_bias=position_bias, head_mask=head_mask - ) - hidden_states = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights - if not self.is_decoder: - hidden_states = self.layer[1](hidden_states) + if past_key_value_state is not None: + assert self.is_decoder, "Only decoder can use `past_key_value_states`" + assert ( + len(past_key_value_state) == 4 + ), "The should be 4 past states. 2 (past / key) for self attention. 2 (past / key) for cross attention. 
Got {} past key / value states".format( + len(past_key_value_state) + ) + self_attn_past_key_value_state = past_key_value_state[:2] + cross_attn_past_key_value_state = past_key_value_state[2:] else: + self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + head_mask=head_mask, + past_key_value_state=self_attn_past_key_value_state, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + if self.is_decoder: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + cross_attention_outputs = self.layer[1]( hidden_states, kv=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, head_mask=head_mask, + past_key_value_state=cross_attn_past_key_value_state, + query_length=query_length, ) hidden_states = cross_attention_outputs[0] - outputs = ( - outputs + cross_attention_outputs[1:] - ) # Keep cross-attention outputs and relative position weights - hidden_states = self.layer[2](hidden_states) + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] - outputs = (hidden_states,) + outputs # add attentions if we output them - return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = 
self.layer[-1](hidden_states) + outputs = (hidden_states,) + + # Add attentions if we output them + outputs = outputs + (present_key_value_state,) + attention_outputs + return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) class T5PreTrainedModel(PreTrainedModel): @@ -531,6 +620,7 @@ class T5Stack(T5PreTrainedModel): self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder + self.output_past = config.output_past and self.is_decoder self.block = nn.ModuleList( [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] @@ -557,6 +647,7 @@ class T5Stack(T5PreTrainedModel): encoder_attention_mask=None, inputs_embeds=None, head_mask=None, + past_key_value_states=None, ): if input_ids is not None and inputs_embeds is not None: @@ -575,25 +666,41 @@ class T5Stack(T5PreTrainedModel): batch_size, seq_length = input_shape + if past_key_value_states is not None: + assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format( + input_shape, (batch_size, 1) + ) + # required mask seq length can be calculated via length of past + # key value states and seq_length = 1 for the last token + mask_seq_length = past_key_value_states[0][0].shape[2] + seq_length + else: + mask_seq_length = seq_length + if attention_mask is None: - attention_mask = torch.ones(batch_size, seq_length).to(inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: encoder_seq_length = encoder_hidden_states.shape[1] encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device) + # initialize past_key_value_states with `None` if past does not exist + if past_key_value_states is 
None: + past_key_value_states = [None] * len(self.block) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] + # Provided a padding mask of dimensions [batch_size, mask_seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] if self.config.is_decoder: - seq_ids = torch.arange(seq_length, device=inputs_embeds.device) - causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + seq_ids = torch.arange(mask_seq_length, device=inputs_embeds.device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None] causal_mask = causal_mask.to(attention_mask) extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + if self.output_past and past_key_value_states[0] is not None: + extended_attention_mask = extended_attention_mask[:, :, -1:, :] else: extended_attention_mask = attention_mask[:, None, None, :] @@ -610,9 +717,9 @@ class T5Stack(T5PreTrainedModel): extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -1e9 - if self.is_decoder: + if self.is_decoder and encoder_attention_mask is not None: # If a 2D ou 3D attention mask is provided for the cross-attention - # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + # we need to make 
broadcastabe to [batch_size, num_heads, mask_seq_length, mask_seq_length] if encoder_attention_mask.dim() == 3: encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] if encoder_attention_mask.dim() == 2: @@ -633,7 +740,7 @@ class T5Stack(T5PreTrainedModel): # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x mask_seq_length x mask_seq_length] if head_mask is not None: if head_mask.dim() == 1: head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) @@ -648,13 +755,15 @@ class T5Stack(T5PreTrainedModel): else: head_mask = [None] * self.config.num_layers + present_key_value_states = () all_hidden_states = () all_attentions = () position_bias = None encoder_decoder_position_bias = None hidden_states = self.dropout(inputs_embeds) - for i, layer_module in enumerate(self.block): + + for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)): if self.output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -666,19 +775,22 @@ class T5Stack(T5PreTrainedModel): encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, head_mask=head_mask[i], + past_key_value_state=past_key_value_state, ) # layer_outputs is a tuple with: - # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - hidden_states = layer_outputs[0] + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + hidden_states, present_key_value_state = layer_outputs[:2] if i == 0: # We share 
the position biases between the layers - the first layer store them - # layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) - position_bias = layer_outputs[2 if self.output_attentions else 1] + # layer_outputs = hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + position_bias = layer_outputs[3 if self.output_attentions else 2] if self.is_decoder: - encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2] + encoder_decoder_position_bias = layer_outputs[5 if self.output_attentions else 3] + # append next layer key value states + present_key_value_states = present_key_value_states + (present_key_value_state,) if self.output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) # We keep only self-attention weights for now + all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -688,11 +800,13 @@ class T5Stack(T5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) outputs = (hidden_states,) + if self.is_decoder and self.output_past: + outputs = outputs + (present_key_value_states,) if self.output_hidden_states: outputs = outputs + (all_hidden_states,) if self.output_attentions: outputs = outputs + (all_attentions,) - return outputs # last-layer hidden state, (all hidden states), (all attentions) + return outputs # last-layer hidden state, (presents,) (all hidden states), (all attentions) T5_START_DOCSTRING = r""" The T5 model was proposed in @@ -719,7 +833,7 @@ T5_INPUTS_DOCSTRING = r""" Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. 
- T5 is a model with relative position embeddings so you should be able to pad the inputs on + T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. If `decoder_past_key_value_states` is used, optionally only the last `input_ids` have to be input (see `decoder_past_key_value_states`). Indices can be obtained using :class:`transformers.T5Tokenizer`. See :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. @@ -739,6 +853,9 @@ T5_INPUTS_DOCSTRING = r""" `T5 Training <./t5.html#training>`_ . decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`): Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default. + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains pre-computed key and value hidden-states of the attention blocks. + Can be used to speed up decoding. If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 
This is useful if you want more control over how to convert `input_ids` indices into associated vectors @@ -780,6 +897,20 @@ class T5Model(T5PreTrainedModel): self.encoder.set_input_embeddings(new_embeddings) self.decoder.set_input_embeddings(new_embeddings) + def set_output_past(self, do_output_past: bool): + self.config.output_past = do_output_past + self.decoder.output_past = do_output_past + for block in self.decoder.block: + block.output_past = do_output_past + block.layer[0].SelfAttention.output_past = do_output_past + block.layer[1].EncDecAttention.output_past = do_output_past + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} @@ -796,6 +927,7 @@ class T5Model(T5PreTrainedModel): encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_past_key_value_states=None, inputs_embeds=None, decoder_inputs_embeds=None, head_mask=None, @@ -805,6 +937,11 @@ class T5Model(T5PreTrainedModel): :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs. last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. + If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output. + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``): + Contains pre-computed key and value hidden-states of the attention blocks. 
+ Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input). + Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, hidden_size)`. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. @@ -837,16 +974,29 @@ class T5Model(T5PreTrainedModel): hidden_states = encoder_outputs[0] + # If decoding with past key value states, only the last tokens + # should be given as an input + if decoder_past_key_value_states is not None and self.decoder.output_past is True: + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, + past_key_value_states=decoder_past_key_value_states, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, ) + if self.decoder.output_past: + past = ((encoder_outputs, decoder_outputs[1]),) + decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:] + return decoder_outputs + encoder_outputs @@ -872,6 +1022,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): def get_input_embeddings(self): return self.shared + def set_output_past(self, do_output_past: bool): + self.config.output_past = do_output_past + self.decoder.output_past = do_output_past + for block in self.decoder.block: + block.output_past = do_output_past + block.layer[0].SelfAttention.output_past = do_output_past + block.layer[1].EncDecAttention.output_past = do_output_past + def set_input_embeddings(self, new_embeddings): 
self.shared = new_embeddings self.encoder.set_input_embeddings(new_embeddings) @@ -883,6 +1041,9 @@ class T5ForConditionalGeneration(T5PreTrainedModel): def get_encoder(self): return self.encoder + def get_decoder(self): + return self.decoder + @add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING) def forward( self, @@ -891,6 +1052,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_outputs=None, decoder_input_ids=None, decoder_attention_mask=None, + decoder_past_key_value_states=None, lm_labels=None, inputs_embeds=None, decoder_inputs_embeds=None, @@ -909,10 +1071,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): Classification loss (cross entropy). prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`) Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + If `decoder_past_key_value_states` is used, only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, config.vocab_size)` is output. + decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``): + Contains pre-computed key and value hidden-states of the attention blocks. + Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input). + Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`. hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`. 
- Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape @@ -948,16 +1114,34 @@ class T5ForConditionalGeneration(T5PreTrainedModel): # get decoder inputs from shifting lm labels to the right decoder_input_ids = self._shift_right(lm_labels) + # If decoding with past key value states, only the last tokens + # should be given as an input + if decoder_past_key_value_states is not None and self.decoder.output_past is True: + assert ( + lm_labels is None + ), "Decoder should not use cached key value states when training. Also consider setting model.set_output_past(False) for less memory consumption" + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + if decoder_inputs_embeds is not None: + decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] + # Decode decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, inputs_embeds=decoder_inputs_embeds, + past_key_value_states=decoder_past_key_value_states, encoder_hidden_states=hidden_states, encoder_attention_mask=attention_mask, head_mask=head_mask, ) + # insert decoder past at right place + # to speed up decoding + if self.decoder.output_past: + past = ((encoder_outputs, decoder_outputs[1]),) + decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:] + sequence_output = decoder_outputs[0] # Rescale output before projecting on vocab # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 @@ -968,9 +1152,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel): if lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)) - decoder_outputs = ( - loss, - ) + decoder_outputs # 
TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + decoder_outputs = (loss,) + decoder_outputs return decoder_outputs + encoder_outputs @@ -978,17 +1161,40 @@ class T5ForConditionalGeneration(T5PreTrainedModel): assert past is not None, "past has to be defined for encoder_outputs" # first step - if type(past) is tuple: - encoder_outputs = past + if len(past) < 2: + encoder_outputs, decoder_past_key_value_states = past, None else: - encoder_outputs = (past,) + encoder_outputs, decoder_past_key_value_states = past[0], past[1] return { "decoder_input_ids": input_ids, + "decoder_past_key_value_states": decoder_past_key_value_states, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, } def _reorder_cache(self, past, beam_idx): - # past does not have to be re-ordered for T5. 
- return past + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if len(past) < 2: + logger.warning("You might want to consider setting model.set_output_past(True) to speed up decoding") + return past + + decoder_past = past[1] + past = (past[0],) + reordered_decoder_past = () + for layer_past_states in decoder_past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 1st position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return past + (reordered_decoder_past,) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 685605e773..f52e2b6fa2 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1417,17 +1417,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): @staticmethod def _reorder_cache(past, beam_idx): - reordered_past = [] - for layer_past in past: - # get the correct batch idx from layer past batch dim - # batch dim of `past` and `mems` is at 2nd position - reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx] - reordered_layer_past = torch.cat(reordered_layer_past, dim=1) - # check that shape matches - assert reordered_layer_past.shape == layer_past.shape - reordered_past.append(reordered_layer_past) - past = tuple(reordered_past) - return past + return tuple(layer_past.index_select(1, beam_idx) for layer_past in past) def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len): diff 
--git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7ac52a672d..31494a066d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -128,6 +128,7 @@ class ModelTesterMixin: for model_class in self.all_model_classes: config.output_attentions = True config.output_hidden_states = False + config.output_past = False model = model_class(config) model.to(torch_device) model.eval() @@ -144,10 +145,9 @@ class ModelTesterMixin: out_len = len(outputs) if self.is_encoder_decoder: - correct_outlen = ( - 4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions - ) + correct_outlen = 4 decoder_attention_idx = 1 + if "lm_labels" in inputs_dict: # loss will come first correct_outlen += 1 # compute loss decoder_attention_idx += 1 diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 1d5e4e9a0a..041ac4a2d9 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -167,17 +167,20 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): model = T5Model(config=config) model.to(torch_device) model.eval() - decoder_output, encoder_output = model( + decoder_output, decoder_past, encoder_output = model( input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, ) - decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + decoder_output, decoder_past, encoder_output = model( + input_ids=input_ids, decoder_input_ids=decoder_input_ids + ) result = { "encoder_output": encoder_output, "decoder_output": decoder_output, + "decoder_past": decoder_past, } self.parent.assertListEqual( list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size] @@ -185,6 +188,13 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): self.parent.assertListEqual( list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, 
self.hidden_size] ) + self.parent.assertEqual(len(decoder_past), 2) + # decoder_past[0] should correspond to encoder output + self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output)) + # There should be `num_layers` key value embeddings stored in decoder_past[1] + self.parent.assertEqual(len(decoder_past[1]), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple + self.parent.assertEqual(len(decoder_past[1][0]), 4) def create_and_check_t5_with_lm_head( self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, @@ -198,8 +208,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): decoder_attention_mask=decoder_attention_mask, lm_labels=lm_labels, ) - loss, prediction_scores, encoder_features = outputs - self.parent.assertEqual(len(outputs), 3) + loss, prediction_scores, _, _ = outputs + self.parent.assertEqual(len(outputs), 4) result = { "loss": loss, "prediction_scores": prediction_scores, @@ -209,6 +219,92 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): ) self.check_loss_output(result) + def create_and_check_t5_decoder_model_past( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, + ): + model = T5Model(config=config).get_decoder() + model.to(torch_device) + model.eval() + + # first forward pass + output, past_key_value_states = model(input_ids) + + # create hypothetical next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + output_from_no_past, _ = model(next_input_ids) + output_from_past, _ = model(next_tokens, past_key_value_states=past_key_value_states) + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = 
output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_t5_decoder_model_attention_mask_past( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, + ): + model = T5Model(config=config).get_decoder() + model.to(torch_device) + model.eval() + + # create attention mask + attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) + + half_seq_length = input_ids.shape[-1] // 2 + attn_mask[:, half_seq_length:] = 0 + + # first forward pass + output, past_key_value_states = model(input_ids, attention_mask=attn_mask) + + # create hypothetical next token and extend to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + + # change a random masked slice from input_ids + random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1 + random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1) + input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens + + # append to next input_ids and attn_mask + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + attn_mask = torch.cat( + [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1, + ) + + # get two different outputs + output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask) + output_from_past, _ = model( + next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask + ) + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() + + # test that outputs are
equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_t5_and_check_t5_generate_with_past_key_value_states( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels, + ): + config.num_layers = 1 + model = T5ForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + torch.manual_seed(0) + model.set_output_past(False) + output_without_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + torch.manual_seed(0) + model.set_output_past(True) + output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -247,6 +343,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs) + def test_t5_decoder_model_past(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs) + + def test_t5_decoder_model_past_with_attn_mask(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs) + + def test_t5_generate_with_past_key_value_states(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs) + @slow def test_model_from_pretrained(self): for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: