Update past_key_values in GPT-2 (#9596)
* Update past_key_values in gpt2 (#9391) * Update generation_utils, and rename some items * Update modeling_gpt2 to avoid an error in gradient_checkpointing * Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2 * Change the location of '_reorder_cache' in modeling files * Add '_reorder_cache' in modeling_ctrl * Fix a bug of my last commit in CTRL * Add '_reorder_cache' to GPT2DoubleHeadsModel * Manage 'use_cache' in config of test_modeling_gpt2 * Clean up the doc string * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix the doc string (GPT-2, CTRL) * improve gradient_checkpointing_behavior Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -503,18 +503,10 @@ class GenerationMixin:
|
||||
|
||||
return model_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
|
||||
generation step.
|
||||
|
||||
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
||||
subclasses of :class:`~transformers.PreTrainedModel`.
|
||||
"""
|
||||
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
raise NotImplementedError(
|
||||
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
|
||||
)
|
||||
|
||||
def _get_logits_warper(
|
||||
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
||||
|
||||
@@ -774,7 +774,7 @@ class BartEncoder(BartPretrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -993,11 +993,13 @@ class BartDecoder(BartPretrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -539,7 +539,14 @@ class BertEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -733,7 +733,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -955,11 +955,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -735,7 +735,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -955,11 +955,13 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch CTRL model."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||
have their past given to this model should not be passed as input ids as they have already been computed.
|
||||
@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = [None] * len(self.h)
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -536,7 +536,14 @@ class ElectraEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -233,7 +233,7 @@ class Attention(nn.Module):
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
present = (key.transpose(-2, -1), value) # transpose to have same shapes
|
||||
else:
|
||||
present = None
|
||||
|
||||
@@ -370,9 +370,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
||||
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||
Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, embed_size_per_head)`).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||
@@ -393,7 +393,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
||||
mc_loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
mc_logits: torch.FloatTensor = None
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
@@ -419,7 +419,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
||||
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
||||
``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
||||
sequence tokens in the vocabulary.
|
||||
|
||||
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
||||
@@ -430,7 +430,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
details.
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
||||
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
|
||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
||||
@@ -640,7 +640,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = [None] * len(self.h)
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
@@ -708,7 +708,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = layer_past.to(hidden_states.device)
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
@@ -717,19 +717,25 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# checkpointing only works with tuple returns, not with lists
|
||||
return tuple(output for output in module(*inputs, use_cache, output_attentions))
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
layer_past,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
@@ -932,6 +938,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
@@ -1095,6 +1113,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`past_key_values` cache if
|
||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -466,7 +466,14 @@ class LayoutLMEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -1695,7 +1695,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -1920,11 +1920,13 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -1226,7 +1226,7 @@ class LongformerEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -742,7 +742,7 @@ class MarianEncoder(MarianPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -958,11 +958,13 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -780,7 +780,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -1002,11 +1002,13 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -746,7 +746,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -966,11 +966,13 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -479,7 +479,14 @@ class RobertaEncoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
logger.warn(
|
||||
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -1137,6 +1137,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
self.crit.cutoff_ends = [0] + new_cutoffs
|
||||
self.crit.n_token = new_num_tokens
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
|
||||
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
|
||||
correct beam_idx at every generation step.
|
||||
"""
|
||||
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -1462,6 +1462,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""
|
||||
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
|
||||
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
|
||||
correct beam_idx at every generation step.
|
||||
"""
|
||||
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
|
||||
@@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False):
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
|
||||
if use_cache:
|
||||
raise ValueError(
|
||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
||||
)
|
||||
logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -131,6 +131,7 @@ class GPT2ModelTester:
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range,
|
||||
use_cache=not gradient_checkpointing,
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
|
||||
Reference in New Issue
Block a user