From b020a736c374460af1b34267283f957988350630 Mon Sep 17 00:00:00 2001 From: Yusuke Mori Date: Wed, 20 Jan 2021 00:00:15 +0900 Subject: [PATCH] Update `past_key_values` in GPT-2 (#9596) * Update past_key_values in gpt2 (#9391) * Update generation_utils, and rename some items * Update modeling_gpt2 to avoid an error in gradient_checkpointing * Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2 * Change the location of '_reorder_cache' in modeling files * Add '_reorder_cache' in modeling_ctrl * Fix a bug of my last commit in CTRL * Add '_reorder_cache' to GPT2DoubleHeadsModel * Manage 'use_cache' in config of test_modeling_gpt2 * Clean up the doc string * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix the doc string (GPT-2, CTRL) * improve gradient_checkpointing_behavior Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen --- src/transformers/generation_utils.py | 16 ++--- src/transformers/models/bart/modeling_bart.py | 10 ++-- src/transformers/models/bert/modeling_bert.py | 9 ++- .../models/blenderbot/modeling_blenderbot.py | 10 ++-- .../modeling_blenderbot_small.py | 10 ++-- src/transformers/models/ctrl/modeling_ctrl.py | 18 +++++- .../models/electra/modeling_electra.py | 9 ++- src/transformers/models/gpt2/modeling_gpt2.py | 58 ++++++++++++++----- .../models/layoutlm/modeling_layoutlm.py | 9 ++- src/transformers/models/led/modeling_led.py | 10 ++-- .../models/longformer/modeling_longformer.py | 2 +- .../models/marian/modeling_marian.py | 10 ++-- .../models/mbart/modeling_mbart.py | 10 ++-- .../models/pegasus/modeling_pegasus.py | 10 ++-- .../models/roberta/modeling_roberta.py | 9 ++- .../models/transfo_xl/modeling_transfo_xl.py | 9 +++ .../models/xlnet/modeling_xlnet.py | 9 +++ ...ng_{{cookiecutter.lowercase_modelname}}.py | 12 ++-- tests/test_modeling_gpt2.py | 1 + 19 files changed, 164 insertions(+), 67 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 6fcf9e5bab..3f933940c6 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -503,18 +503,10 @@ class GenerationMixin: return model_kwargs - @staticmethod - def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]: - """ - This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if - :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is - called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every - generation step. - - For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in - subclasses of :class:`~transformers.PreTrainedModel`. - """ - return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past) + def _reorder_cache(self, past, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}" + ) def _get_logits_warper( self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index b34d81741d..6a6fac4690 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -774,7 +774,7 @@ class BartEncoder(BartPretrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -993,11 +993,13 @@ class BartDecoder(BartPretrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 590363a124..72c795e621 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -539,7 +539,14 @@ class BertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index a47017278e..2b962c5e8f 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -733,7 +733,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -955,11 +955,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index dd8e1020cc..8003505d59 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -735,7 +735,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -955,11 +955,13 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 76f0402d77..da50095408 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -15,6 +15,8 @@ # limitations under the License. """ PyTorch CTRL model.""" +from typing import Tuple + import numpy as np import torch import torch.nn as nn @@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r""" details. `What are input IDs? <../glossary.html#input-ids>`__ - past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): + past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`): Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which have their past given to this model should not be passed as input ids as they have already been computed. @@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel): if past_key_values is None: past_length = 0 - past_key_values = [None] * len(self.h) + past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) if position_ids is None: @@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): attentions=transformer_outputs.attentions, ) + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + @add_start_docstrings( """ diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 1f83d934e4..0374871a77 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -536,7 +536,14 @@ class ElectraEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 25ca390fa3..175f9b1c42 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -17,7 +17,7 @@ import os from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -233,7 +233,7 @@ class Attention(nn.Module): value = torch.cat((past_value, value), dim=-2) if use_cache is True: - present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking + present = (key.transpose(-2, -1), value) # transpose to have same shapes else: present = None @@ -370,9 +370,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). - past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): - List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, - batch_size, num_heads, sequence_length, embed_size_per_head)`). + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads, + sequence_length, embed_size_per_head)`). Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. @@ -393,7 +393,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput): mc_loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None mc_logits: torch.FloatTensor = None - past_key_values: Optional[List[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -419,7 +419,7 @@ GPT2_INPUTS_DOCSTRING = r""" Args: input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else - ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input + ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input sequence tokens in the vocabulary. If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be @@ -430,7 +430,7 @@ GPT2_INPUTS_DOCSTRING = r""" details. `What are input IDs? <../glossary.html#input-ids>`__ - past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): + past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`): Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which have their past given to this model should not be passed as ``input_ids`` as they have already been @@ -640,7 +640,7 @@ class GPT2Model(GPT2PreTrainedModel): if past_key_values is None: past_length = 0 - past_key_values = [None] * len(self.h) + past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) if position_ids is None: @@ -708,7 +708,7 @@ class GPT2Model(GPT2PreTrainedModel): torch.cuda.set_device(hidden_states.device) # Ensure layer_past is on same device as hidden_states (might not be correct) if layer_past is not None: - layer_past = layer_past.to(hidden_states.device) + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -717,19 +717,25 @@ class GPT2Model(GPT2PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): - # checkpointing only works with tuple returns, not with lists - return tuple(output for output in module(*inputs, use_cache, output_attentions)) + # None for past_key_value + return module(*inputs, use_cache, output_attentions) return custom_forward outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, - layer_past, + None, attention_mask, head_mask[i], encoder_hidden_states, @@ -932,6 +938,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): cross_attentions=transformer_outputs.cross_attentions, ) + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + @add_start_docstrings( """ @@ -1095,6 +1113,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): attentions=transformer_outputs.attentions, ) + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + @add_start_docstrings( """ diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index df4e4cca21..31c46fd9f9 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -466,7 +466,14 @@ class LayoutLMEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 3bba188969..c79cc2a0e9 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1695,7 +1695,7 @@ class LEDEncoder(LEDPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1920,11 +1920,13 @@ class LEDDecoder(LEDPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 25cbb30c4d..2754afc345 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1226,7 +1226,7 @@ class LongformerEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 6e8bdf0d0c..0ed8e1cfd1 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -742,7 +742,7 @@ class MarianEncoder(MarianPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -958,11 +958,13 @@ class MarianDecoder(MarianPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index ac631057d8..9c5f245b6c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -780,7 +780,7 @@ class MBartEncoder(MBartPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1002,11 +1002,13 @@ class MBartDecoder(MBartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 13deb70da9..ecae05aab5 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -746,7 +746,7 @@ class PegasusEncoder(PegasusPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -966,11 +966,13 @@ class PegasusDecoder(PegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index f36e7673b9..3213ff488e 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -479,7 +479,14 @@ class RobertaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + ) + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 3e579566f6..01d5cf0454 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -1137,6 +1137,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): self.crit.cutoff_ends = [0] + new_cutoffs self.crit.n_token = new_num_tokens + @staticmethod + def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]: + """ + This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or + :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the + correct beam_idx at every generation step. + """ + return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems] + @add_start_docstrings( """ diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index cf8d67695c..e873996d36 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -1462,6 +1462,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): attentions=transformer_outputs.attentions, ) + @staticmethod + def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]: + """ + This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or + :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the + correct beam_idx at every generation step. + """ + return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems] + @add_start_docstrings( """ diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 5a12c9b795..9148157b7d 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." - ) + logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index f60fa45bcb..bf70492967 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -131,6 +131,7 @@ class GPT2ModelTester: n_ctx=self.max_position_embeddings, # type_vocab_size=self.type_vocab_size, # initializer_range=self.initializer_range, + use_cache=not gradient_checkpointing, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id,