diff --git a/docs/source/main_classes/output.rst b/docs/source/main_classes/output.rst index 109bc49e1c..044dfd82ce 100644 --- a/docs/source/main_classes/output.rst +++ b/docs/source/main_classes/output.rst @@ -126,13 +126,6 @@ CausalLMOutputWithCrossAttentions :members: -CausalLMOutputWithPastAndCrossAttentions -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions - :members: - - CausalLMOutputWithPast ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index f08ca2c29b..0c4c503905 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -175,11 +175,19 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the + cached key, value states of the self-attention and the cross-attention layers if model is used in + encoder-decoder setting. Only relevant if ``config.is_decoder = True``. + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. """ last_hidden_state: torch.FloatTensor = None pooler_output: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -379,53 +387,18 @@ class CausalLMOutputWithCrossAttentions(ModelOutput): Cross attentions weights after the attention softmax, used to compute the weighted average in the cross-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class CausalLMOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): - Language modeling loss (for next-token prediction). - logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, - batch_size, num_heads, sequence_length, embed_size_per_head)`). + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of :obj:`torch.FloatTensor` tuples of length :obj:`config.n_layers`, with each tuple containing the + cached key, value states of the self-attention and the cross-attention layers if model is used in + encoder-decoder setting. Only relevant if ``config.is_decoder = True``. Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. - hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): - Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) - of shape :obj:`(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): - Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, - sequence_length, sequence_length)`. - - Cross attentions weights after the attention softmax, used to compute the weighted average in the - cross-attention heads. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py index f0505f8078..325e95e195 100755 --- a/src/transformers/models/albert/modeling_albert.py +++ b/src/transformers/models/albert/modeling_albert.py @@ -217,7 +217,9 @@ class AlbertEmbeddings(nn.Module): self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): if input_ids is not None: input_shape = input_ids.size() else: @@ -226,7 +228,7 @@ class AlbertEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, :seq_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index e265375e8d..5555704858 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -98,6 +98,9 @@ class BertConfig(PretrainedConfig): `__. For more information on :obj:`"relative_key_query"`, please refer to `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.) `__. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. Examples:: @@ -131,6 +134,7 @@ class BertConfig(PretrainedConfig): pad_token_id=0, gradient_checkpointing=False, position_embedding_type="absolute", + use_cache=True, **kwargs ): super().__init__(pad_token_id=pad_token_id, **kwargs) @@ -149,3 +153,4 @@ class BertConfig(PretrainedConfig): self.layer_norm_eps = layer_norm_eps self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type + self.use_cache = use_cache diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 46649f2413..b4d11fda43 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -36,7 +36,7 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_outputs import ( - BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, @@ -180,7 +180,9 @@ class BertEmbeddings(nn.Module): self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): if input_ids is not None: input_shape = input_ids.size() else: @@ -189,7 +191,7 @@ class BertEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, :seq_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) @@ -230,6 +232,8 @@ class BertSelfAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) @@ -242,6 +246,7 @@ class BertSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) @@ -249,17 +254,37 @@ class BertSelfAttention(nn.Module): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - if encoder_hidden_states is not None: - mixed_key_layer = self.key(encoder_hidden_states) - mixed_value_layer = self.value(encoder_hidden_states) + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -303,6 +328,9 @@ class BertSelfAttention(nn.Module): context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -352,6 +380,7 @@ class BertAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): self_outputs = self.self( @@ -360,6 +389,7 @@ class BertAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -417,36 +447,60 @@ class BertLayer(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: assert hasattr( self, "crossattention" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, + cross_attn_past_key_value, output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + return outputs def feed_forward_chunk(self, attention_output): @@ -468,6 +522,8 @@ class BertEncoder(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, @@ -475,17 +531,19 @@ class BertEncoder(nn.Module): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - + past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False): def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, output_attentions) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -504,9 +562,13 @@ class BertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -518,11 +580,18 @@ class BertEncoder(nn.Module): if not return_dict: return tuple( v - for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) - return BaseModelOutputWithCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -799,6 +868,8 @@ class BertModel(BertPreTrainedModel): inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -813,6 +884,15 @@ class BertModel(BertPreTrainedModel): - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -820,19 +900,29 @@ class BertModel(BertPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + batch_size, seq_length = input_shape elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) @@ -859,7 +949,11 @@ class BertModel(BertPreTrainedModel): head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, @@ -867,6 +961,8 @@ class BertModel(BertPreTrainedModel): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -880,6 +976,7 @@ class BertModel(BertPreTrainedModel): return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, @@ -1029,6 +1126,8 @@ class BertLMHeadModel(BertPreTrainedModel): encoder_hidden_states=None, encoder_attention_mask=None, labels=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -1047,6 +1146,15 @@ class BertLMHeadModel(BertPreTrainedModel): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Returns: @@ -1066,6 +1174,8 @@ class BertLMHeadModel(BertPreTrainedModel): >>> prediction_logits = outputs.logits """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False outputs = self.bert( input_ids, @@ -1076,6 +1186,8 @@ class BertLMHeadModel(BertPreTrainedModel): inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1099,20 +1211,30 @@ class BertLMHeadModel(BertPreTrainedModel): return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) - def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + return {"input_ids": input_ids, "attention_mask": attention_mask} + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING) class BertForMaskedLM(BertPreTrainedModel): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 7336fd005c..37d4a5dd37 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -26,7 +26,7 @@ from ...file_utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import logging from ..bert.modeling_bert import BertEncoder @@ -144,7 +144,7 @@ class BertGenerationEmbeddings(nn.Module): # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) - def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0): if input_ids is not None: input_shape = input_ids.size() else: @@ -153,7 +153,7 @@ class BertGenerationEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, :seq_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) @@ -297,7 +297,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder", - output_type=BaseModelOutputWithCrossAttentions, + output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -309,6 +309,8 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -321,6 +323,15 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -332,19 +343,28 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + batch_size, seq_length = input_shape elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) + extended_attention_mask = None + if not use_cache: + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -364,7 +384,12 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) encoder_outputs = self.encoder( embedding_output, @@ -372,6 +397,8 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -381,8 +408,9 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel): if not return_dict: return (sequence_output,) + encoder_outputs[1:] - return BaseModelOutputWithCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=sequence_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, @@ -437,6 +465,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): encoder_hidden_states=None, encoder_attention_mask=None, labels=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -455,6 +485,15 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Returns: @@ -474,6 +513,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): >>> prediction_logits = outputs.logits """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False outputs = self.bert( input_ids, @@ -483,6 +524,8 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -506,16 +549,26 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) - def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + return {"input_ids": input_ids, "attention_mask": attention_mask} + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dbd16516de..fcab923cd9 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -33,6 +33,7 @@ from ...file_utils import ( ) from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPastAndCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, @@ -168,7 +169,9 @@ class ElectraEmbeddings(nn.Module): self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): if input_ids is not None: input_shape = input_ids.size() else: @@ -177,7 +180,7 @@ class ElectraEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, :seq_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) @@ -219,6 +222,8 @@ class ElectraSelfAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) @@ -231,6 +236,7 @@ class ElectraSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) @@ -238,17 +244,37 @@ class ElectraSelfAttention(nn.Module): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - if encoder_hidden_states is not None: - mixed_key_layer = self.key(encoder_hidden_states) - mixed_value_layer = self.value(encoder_hidden_states) + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -292,6 +318,9 @@ class ElectraSelfAttention(nn.Module): context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -343,6 +372,7 @@ class ElectraAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): self_outputs = self.self( @@ -351,6 +381,7 @@ class ElectraAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -411,36 +442,60 @@ class ElectraLayer(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: assert hasattr( self, "crossattention" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, + cross_attn_past_key_value, output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + return outputs def feed_forward_chunk(self, attention_output): @@ -463,6 +518,8 @@ class ElectraEncoder(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, @@ -470,17 +527,19 @@ class ElectraEncoder(nn.Module): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - + past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False): def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, output_attentions) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -499,9 +558,13 @@ class ElectraEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -513,11 +576,18 @@ class ElectraEncoder(nn.Module): if not return_dict: return tuple( v - for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) - return BaseModelOutputWithCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 5c27559fd8..e306d5f0c0 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -345,11 +345,11 @@ class EncoderDecoderModel(PreTrainedModel): decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, - past_key_values=None, # TODO: (PVP) implement :obj:`use_cache` + past_key_values=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, - use_cache=None, # TODO: (PVP) implement :obj:`use_cache` + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -413,18 +413,19 @@ class EncoderDecoderModel(PreTrainedModel): labels=labels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, return_dict=return_dict, **kwargs_decoder, ) - # TODO(PVP): currently it is not possible to use `past` if not return_dict: return decoder_outputs + encoder_outputs return Seq2SeqLMOutput( loss=decoder_outputs.loss, logits=decoder_outputs.logits, - past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works + past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, @@ -433,24 +434,19 @@ class EncoderDecoderModel(PreTrainedModel): encoder_attentions=encoder_outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids) + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None input_dict = { "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, "decoder_input_ids": decoder_inputs["input_ids"], "encoder_outputs": encoder_outputs, + "past_key_values": past, + "use_cache": use_cache, } - - # Ideally all models should have a :obj:`use_cache` - # leave following to ifs until all have it implemented - if "use_cache" in decoder_inputs: - input_dict["decoder_use_cache"] = decoder_inputs["use_cache"] - - if "past_key_values" in decoder_inputs: - input_dict["past_key_values"] = decoder_inputs["past_key_values"] - return input_dict def _reorder_cache(self, past, beam_idx): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index d34bcd7165..f85be6449e 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -33,7 +33,7 @@ from ...file_utils import ( ) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, - CausalLMOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, ) from ...modeling_utils import ( @@ -851,7 +851,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2", - output_type=CausalLMOutputWithPastAndCrossAttentions, + output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -916,7 +916,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPastAndCrossAttentions( + return CausalLMOutputWithCrossAttentions( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index a81c3fa1f2..751a65959e 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -24,7 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward from ...modeling_outputs import ( - BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, MaskedLMOutput, TokenClassifierOutput, @@ -151,6 +151,8 @@ class LayoutLMSelfAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) @@ -163,6 +165,7 @@ class LayoutLMSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) @@ -170,17 +173,37 @@ class LayoutLMSelfAttention(nn.Module): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - if encoder_hidden_states is not None: - mixed_key_layer = self.key(encoder_hidden_states) - mixed_value_layer = self.value(encoder_hidden_states) + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -224,6 +247,9 @@ class LayoutLMSelfAttention(nn.Module): context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -275,6 +301,7 @@ class LayoutLMAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): self_outputs = self.self( @@ -283,6 +310,7 @@ class LayoutLMAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -343,36 +371,60 @@ class LayoutLMLayer(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: assert hasattr( self, "crossattention" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, + cross_attn_past_key_value, output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + return outputs def feed_forward_chunk(self, attention_output): @@ -395,6 +447,8 @@ class LayoutLMEncoder(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, @@ -402,17 +456,19 @@ class LayoutLMEncoder(nn.Module): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - + past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False): def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, output_attentions) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -431,9 +487,13 @@ class LayoutLMEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -445,11 +505,18 @@ class LayoutLMEncoder(nn.Module): if not return_dict: return tuple( v - for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) - return BaseModelOutputWithCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 8cdc82ea06..5e14f05dbf 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -424,7 +424,6 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru return attention_mask -# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids def create_position_ids_from_input_ids(input_ids, padding_idx): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 09606953cc..3e9eca1767 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -29,7 +29,7 @@ from ...file_utils import ( replace_return_docstrings, ) from ...modeling_outputs import ( - BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, @@ -91,25 +91,23 @@ class RobertaEmbeddings(nn.Module): config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx ) - def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): if position_ids is None: if input_ids is not None: # Create the position ids from the input token ids. Any padded tokens remain padded. - position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx).to(input_ids.device) + position_ids = create_position_ids_from_input_ids( + input_ids, self.padding_idx, past_key_values_length + ).to(input_ids.device) else: position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) - # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward if input_ids is not None: input_shape = input_ids.size() else: input_shape = inputs_embeds.size()[:-1] - seq_length = input_shape[1] - - if position_ids is None: - position_ids = self.position_ids[:, :seq_length] - if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) @@ -167,6 +165,8 @@ class RobertaSelfAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.is_decoder = config.is_decoder + def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(*new_x_shape) @@ -179,6 +179,7 @@ class RobertaSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) @@ -186,17 +187,37 @@ class RobertaSelfAttention(nn.Module): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - if encoder_hidden_states is not None: - mixed_key_layer = self.key(encoder_hidden_states) - mixed_value_layer = self.value(encoder_hidden_states) + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -240,6 +261,9 @@ class RobertaSelfAttention(nn.Module): context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -291,6 +315,7 @@ class RobertaAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): self_outputs = self.self( @@ -299,6 +324,7 @@ class RobertaAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -359,36 +385,60 @@ class RobertaLayer(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: assert hasattr( self, "crossattention" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, + cross_attn_past_key_value, output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + return outputs def feed_forward_chunk(self, attention_output): @@ -411,6 +461,8 @@ class RobertaEncoder(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, @@ -418,17 +470,19 @@ class RobertaEncoder(nn.Module): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - + past_key_value = past_key_values[i] if past_key_values is not None else None if getattr(self.config, "gradient_checkpointing", False): def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, output_attentions) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -447,9 +501,13 @@ class RobertaEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -461,11 +519,18 @@ class RobertaEncoder(nn.Module): if not return_dict: return tuple( v - for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None ) - return BaseModelOutputWithCrossAttentions( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -646,6 +711,8 @@ class RobertaModel(RobertaPreTrainedModel): inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -658,26 +725,44 @@ class RobertaModel(RobertaPreTrainedModel): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + + if not self.config.is_decoder: + use_cache = False if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + batch_size, seq_length = input_shape elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape else: raise ValueError("You have to specify either input_ids or inputs_embeds") device = input_ids.device if input_ids is not None else inputs_embeds.device + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if attention_mask is None: - attention_mask = torch.ones(input_shape, device=device) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) if token_type_ids is None: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) @@ -704,7 +789,11 @@ class RobertaModel(RobertaPreTrainedModel): head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( - input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, ) encoder_outputs = self.encoder( embedding_output, @@ -712,6 +801,8 @@ class RobertaModel(RobertaPreTrainedModel): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -725,6 +816,7 @@ class RobertaModel(RobertaPreTrainedModel): return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, @@ -768,6 +860,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel): encoder_hidden_states=None, encoder_attention_mask=None, labels=None, + past_key_values=None, + use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, @@ -787,6 +881,15 @@ class RobertaForCausalLM(RobertaPreTrainedModel): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). Returns: @@ -806,6 +909,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel): >>> prediction_logits = outputs.logits """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False outputs = self.roberta( input_ids, @@ -816,6 +921,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel): inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -839,20 +946,30 @@ class RobertaForCausalLM(RobertaPreTrainedModel): return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) - def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): input_shape = input_ids.shape - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + return {"input_ids": input_ids, "attention_mask": attention_mask} + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) class RobertaForMaskedLM(RobertaPreTrainedModel): @@ -1357,7 +1474,7 @@ class RobertaForQuestionAnswering(RobertaPreTrainedModel): ) -def create_position_ids_from_input_ids(input_ids, padding_idx): +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): """ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols are ignored. This is modified from fairseq's `utils.make_positions`. @@ -1369,5 +1486,5 @@ def create_position_ids_from_input_ids(input_ids, padding_idx): """ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. mask = input_ids.ne(padding_idx).int() - incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask return incremental_indices.long() + padding_idx diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 3c2a2d406a..0232363daa 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -347,6 +347,7 @@ class TapasSelfAttention(nn.Module): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.is_decoder = config.is_decoder def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -360,6 +361,7 @@ class TapasSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) @@ -367,17 +369,30 @@ class TapasSelfAttention(nn.Module): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. - if encoder_hidden_states is not None: - mixed_key_layer = self.key(encoder_hidden_states) - mixed_value_layer = self.value(encoder_hidden_states) + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) else: - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) + + if self.is_decoder: + past_key_value = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -404,6 +419,8 @@ class TapasSelfAttention(nn.Module): context_layer = context_layer.view(*new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + if self.is_decoder: + outputs = outputs + (past_key_value,) return outputs @@ -455,6 +472,7 @@ class TapasAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): self_outputs = self.self( @@ -463,6 +481,7 @@ class TapasAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -523,36 +542,60 @@ class TapasLayer(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_value=None, output_attentions=False, ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: assert hasattr( self, "crossattention" ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`" + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, + cross_attn_past_key_value, output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + return outputs def feed_forward_chunk(self, attention_output): @@ -574,6 +617,8 @@ class TapasEncoder(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, + use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, @@ -590,7 +635,7 @@ class TapasEncoder(nn.Module): def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, output_attentions) + return module(*inputs, past_key_values, output_attentions) return custom_forward @@ -609,6 +654,7 @@ class TapasEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, + past_key_values, output_attentions, ) hidden_states = layer_outputs[0] diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index b870b38d9c..f38816e095 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -150,7 +150,7 @@ class BartModelTester: input_ids = inputs_dict["input_ids"] # first forward pass - outputs = model(input_ids, use_cache=True) + outputs = model(input_ids, attention_mask=inputs_dict["attention_mask"], use_cache=True) output, past_key_values = outputs.to_tuple() diff --git a/tests/test_modeling_bert.py b/tests/test_modeling_bert.py index 440dfbb18e..7d2bd9bc2a 100755 --- a/tests/test_modeling_bert.py +++ b/tests/test_modeling_bert.py @@ -260,6 +260,66 @@ class BertModelTester: ) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = BertLMHeadModel(config=config).to(torch_device).eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + def create_and_check_for_next_sequence_prediction( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -454,6 +514,10 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs) + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + def test_for_multiple_choice(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) diff --git a/tests/test_modeling_bert_generation.py b/tests/test_modeling_bert_generation.py index 16609de6e2..aa22970a96 100755 --- a/tests/test_modeling_bert_generation.py +++ b/tests/test_modeling_bert_generation.py @@ -25,6 +25,8 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r if is_torch_available(): + import torch + from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder @@ -156,6 +158,64 @@ class BertGenerationEncoderTester: ) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + token_labels, + encoder_hidden_states, + encoder_attention_mask, + **kwargs, + ): + config.is_decoder = True + config.add_cross_attention = True + model = BertGenerationDecoder(config=config).to(torch_device).eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + def create_and_check_for_causal_lm( self, config, @@ -203,6 +263,10 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + def test_model_as_decoder_with_default_input_mask(self): # This regression test was failing with PyTorch < 1.3 ( diff --git a/tests/test_modeling_roberta.py b/tests/test_modeling_roberta.py index c02f5c1b7f..be675eda6d 100644 --- a/tests/test_modeling_roberta.py +++ b/tests/test_modeling_roberta.py @@ -198,6 +198,74 @@ class RobertaModelTester: result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = RobertaForCausalLM(config=config).to(torch_device).eval() + + # make sure that ids don't start with pad token + mask = input_ids.ne(config.pad_token_id).long() + input_ids = input_ids * mask + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + + # make sure that ids don't start with pad token + mask = next_tokens.ne(config.pad_token_id).long() + next_tokens = next_tokens * mask + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + def create_and_check_for_masked_lm( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -337,6 +405,10 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + def test_for_masked_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)