Update past_key_values in GPT-2 (#9596)

* Update past_key_values in gpt2 (#9391)

* Update generation_utils, and rename some items

* Update modeling_gpt2 to avoid an error in gradient_checkpointing

* Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2

* Change the location of '_reorder_cache' in modeling files

* Add '_reorder_cache' in modeling_ctrl

* Fix a bug from my last commit in CTRL

* Add '_reorder_cache' to GPT2DoubleHeadsModel

* Manage 'use_cache' in config of test_modeling_gpt2

* Clean up the doc string

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the doc string (GPT-2, CTRL)

* Improve gradient_checkpointing behavior

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Yusuke Mori
2021-01-20 00:00:15 +09:00
committed by GitHub
parent 97b787fb4e
commit b020a736c3
19 changed files with 164 additions and 67 deletions

View File

@@ -503,18 +503,10 @@ class GenerationMixin:
return model_kwargs return model_kwargs
@staticmethod def _reorder_cache(self, past, beam_idx):
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]: raise NotImplementedError(
""" f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if )
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
subclasses of :class:`~transformers.PreTrainedModel`.
"""
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
def _get_logits_warper( def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None

View File

@@ -774,7 +774,7 @@ class BartEncoder(BartPretrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -993,11 +993,13 @@ class BartDecoder(BartPretrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -539,7 +539,14 @@ class BertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -733,7 +733,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -955,11 +955,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -735,7 +735,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -955,11 +955,13 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
""" PyTorch CTRL model.""" """ PyTorch CTRL model."""
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
details. details.
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as input ids as they have already been computed. have their past given to this model should not be passed as input ids as they have already been computed.
@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
if past_key_values is None: if past_key_values is None:
past_length = 0 past_length = 0
past_key_values = [None] * len(self.h) past_key_values = tuple([None] * len(self.h))
else: else:
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.

Every tensor inside every per-layer tuple of :obj:`past` is re-indexed along dimension 0 with :obj:`beam_idx`
(which is first moved to that tensor's device). The result is returned as a new tuple of tuples; :obj:`past`
itself is left untouched.
"""
# NOTE(review): dimension 0 is presumably the batch/beam dimension of each cached tensor -- confirm.
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings( @add_start_docstrings(
""" """

View File

@@ -536,7 +536,14 @@ class ElectraEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -17,7 +17,7 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -233,7 +233,7 @@ class Attention(nn.Module):
value = torch.cat((past_value, value), dim=-2) value = torch.cat((past_value, value), dim=-2)
if use_cache is True: if use_cache is True:
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking present = (key.transpose(-2, -1), value) # transpose to have same shapes
else: else:
present = None present = None
@@ -370,9 +370,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
batch_size, num_heads, sequence_length, embed_size_per_head)`). sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding. :obj:`past_key_values` input) to speed up sequential decoding.
@@ -393,7 +393,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
mc_loss: Optional[torch.FloatTensor] = None mc_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None mc_logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -419,7 +419,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Args: Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`): input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
sequence tokens in the vocabulary. sequence tokens in the vocabulary.
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
@@ -430,7 +430,7 @@ GPT2_INPUTS_DOCSTRING = r"""
details. details.
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`): past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as ``input_ids`` as they have already been have their past given to this model should not be passed as ``input_ids`` as they have already been
@@ -640,7 +640,7 @@ class GPT2Model(GPT2PreTrainedModel):
if past_key_values is None: if past_key_values is None:
past_length = 0 past_length = 0
past_key_values = [None] * len(self.h) past_key_values = tuple([None] * len(self.h))
else: else:
past_length = past_key_values[0][0].size(-2) past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
@@ -708,7 +708,7 @@ class GPT2Model(GPT2PreTrainedModel):
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct) # Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None: if layer_past is not None:
layer_past = layer_past.to(hidden_states.device) layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device) attention_mask = attention_mask.to(hidden_states.device)
@@ -717,19 +717,25 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# checkpointing only works with tuple returns, not with lists # None for past_key_value
return tuple(output for output in module(*inputs, use_cache, output_attentions)) return module(*inputs, use_cache, output_attentions)
return custom_forward return custom_forward
outputs = torch.utils.checkpoint.checkpoint( outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block), create_custom_forward(block),
hidden_states, hidden_states,
layer_past, None,
attention_mask, attention_mask,
head_mask[i], head_mask[i],
encoder_hidden_states, encoder_hidden_states,
@@ -932,6 +938,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
cross_attentions=transformer_outputs.cross_attentions, cross_attentions=transformer_outputs.cross_attentions,
) )
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.

Every tensor inside every per-layer tuple of :obj:`past` is re-indexed along dimension 0 with :obj:`beam_idx`
(which is first moved to that tensor's device). The result is returned as a new tuple of tuples; :obj:`past`
itself is left untouched.
"""
# NOTE(review): dimension 0 is presumably the batch/beam dimension of each cached tensor -- confirm.
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings( @add_start_docstrings(
""" """
@@ -1095,6 +1113,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.

Every tensor inside every per-layer tuple of :obj:`past` is re-indexed along dimension 0 with :obj:`beam_idx`
(which is first moved to that tensor's device). The result is returned as a new tuple of tuples; :obj:`past`
itself is left untouched.
"""
# NOTE(review): dimension 0 is presumably the batch/beam dimension of each cached tensor -- confirm.
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings( @add_start_docstrings(
""" """

View File

@@ -466,7 +466,14 @@ class LayoutLMEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -1695,7 +1695,7 @@ class LEDEncoder(LEDPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None, None) layer_outputs = (None, None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -1920,11 +1920,13 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -1226,7 +1226,7 @@ class LongformerEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -742,7 +742,7 @@ class MarianEncoder(MarianPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -958,11 +958,13 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -780,7 +780,7 @@ class MBartEncoder(MBartPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -1002,11 +1002,13 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -746,7 +746,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -966,11 +966,13 @@ class PegasusDecoder(PegasusPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
) )
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -479,7 +479,14 @@ class RobertaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -1137,6 +1137,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.crit.cutoff_ends = [0] + new_cutoffs self.crit.cutoff_ends = [0] + new_cutoffs
self.crit.n_token = new_num_tokens self.crit.n_token = new_num_tokens
@staticmethod
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
"""
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PreTrainedModel.beam_search` or
:meth:`~transformers.PreTrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
correct beam_idx at every generation step.

Each tensor in :obj:`mems` is re-indexed along dimension 1 with :obj:`beam_idx` (which is first moved to that
tensor's device). The result is returned as a new list; :obj:`mems` itself is left untouched.
"""
# NOTE(review): dimension 1 is presumably the batch/beam dimension of each memory tensor -- confirm.
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
@add_start_docstrings( @add_start_docstrings(
""" """

View File

@@ -1462,6 +1462,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@staticmethod
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
"""
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PreTrainedModel.beam_search` or
:meth:`~transformers.PreTrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
correct beam_idx at every generation step.

Each tensor in :obj:`mems` is re-indexed along dimension 1 with :obj:`beam_idx` (which is first moved to that
tensor's device). The result is returned as a new list; :obj:`mems` itself is left untouched.
"""
# NOTE(review): dimension 1 is presumably the batch/beam dimension of each memory tensor -- confirm.
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
@add_start_docstrings( @add_start_docstrings(
""" """

View File

@@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
if self.training and (dropout_probability < self.layerdrop): # skip the layer if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
@@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False): if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache: if use_cache:
raise ValueError( logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." use_cache = False
)
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):

View File

@@ -131,6 +131,7 @@ class GPT2ModelTester:
n_ctx=self.max_position_embeddings, n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size, # type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range, # initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id, bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id, eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id, pad_token_id=self.pad_token_id,