Update past_key_values in GPT-2 (#9596)
* Update past_key_values in gpt2 (#9391) * Update generation_utils, and rename some items * Update modeling_gpt2 to avoid an error in gradient_checkpointing * Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2 * Change the location of '_reorder_cache' in modeling files * Add '_reorder_cache' in modeling_ctrl * Fix a bug of my last commit in CTRL * Add '_reorder_cache' to GPT2DoubleHeadsModel * Manage 'use_cache' in config of test_modeling_gpt2 * Clean up the doc string * Update src/transformers/models/gpt2/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Fix the doc string (GPT-2, CTRL) * improve gradient_checkpointing_behavior Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -503,18 +503,10 @@ class GenerationMixin:
|
|||||||
|
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
@staticmethod
|
def _reorder_cache(self, past, beam_idx):
|
||||||
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
|
raise NotImplementedError(
|
||||||
"""
|
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
|
||||||
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
|
)
|
||||||
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
|
||||||
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
|
|
||||||
generation step.
|
|
||||||
|
|
||||||
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
|
|
||||||
subclasses of :class:`~transformers.PreTrainedModel`.
|
|
||||||
"""
|
|
||||||
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
|
|
||||||
|
|
||||||
def _get_logits_warper(
|
def _get_logits_warper(
|
||||||
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
|
||||||
|
|||||||
@@ -774,7 +774,7 @@ class BartEncoder(BartPretrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -993,11 +993,13 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -539,7 +539,14 @@ class BertEncoder(nn.Module):
|
|||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
|
||||||
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warn(
|
||||||
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -733,7 +733,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -955,11 +955,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -735,7 +735,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -955,11 +955,13 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -15,6 +15,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch CTRL model."""
|
""" PyTorch CTRL model."""
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
|
|||||||
details.
|
details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
|
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||||
have their past given to this model should not be passed as input ids as they have already been computed.
|
have their past given to this model should not be passed as input ids as they have already been computed.
|
||||||
@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
|||||||
|
|
||||||
if past_key_values is None:
|
if past_key_values is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
past_key_values = [None] * len(self.h)
|
past_key_values = tuple([None] * len(self.h))
|
||||||
else:
|
else:
|
||||||
past_length = past_key_values[0][0].size(-2)
|
past_length = past_key_values[0][0].size(-2)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
This function is used to re-order the :obj:`past_key_values` cache if
|
||||||
|
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||||
|
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return tuple(
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||||
|
for layer_past in past
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -536,7 +536,14 @@ class ElectraEncoder(nn.Module):
|
|||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
|
||||||
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warn(
|
||||||
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -233,7 +233,7 @@ class Attention(nn.Module):
|
|||||||
value = torch.cat((past_value, value), dim=-2)
|
value = torch.cat((past_value, value), dim=-2)
|
||||||
|
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
present = (key.transpose(-2, -1), value) # transpose to have same shapes
|
||||||
else:
|
else:
|
||||||
present = None
|
present = None
|
||||||
|
|
||||||
@@ -370,9 +370,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
|||||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||||
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
||||||
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
||||||
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
||||||
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
|
||||||
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
sequence_length, embed_size_per_head)`).
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
||||||
:obj:`past_key_values` input) to speed up sequential decoding.
|
:obj:`past_key_values` input) to speed up sequential decoding.
|
||||||
@@ -393,7 +393,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
|
|||||||
mc_loss: Optional[torch.FloatTensor] = None
|
mc_loss: Optional[torch.FloatTensor] = None
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
mc_logits: torch.FloatTensor = None
|
mc_logits: torch.FloatTensor = None
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
@@ -419,7 +419,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
Args:
|
Args:
|
||||||
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
||||||
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
||||||
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
||||||
sequence tokens in the vocabulary.
|
sequence tokens in the vocabulary.
|
||||||
|
|
||||||
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
||||||
@@ -430,7 +430,7 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
details.
|
details.
|
||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
|
||||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||||
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
||||||
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
||||||
@@ -640,7 +640,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
if past_key_values is None:
|
if past_key_values is None:
|
||||||
past_length = 0
|
past_length = 0
|
||||||
past_key_values = [None] * len(self.h)
|
past_key_values = tuple([None] * len(self.h))
|
||||||
else:
|
else:
|
||||||
past_length = past_key_values[0][0].size(-2)
|
past_length = past_key_values[0][0].size(-2)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
@@ -708,7 +708,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
torch.cuda.set_device(hidden_states.device)
|
torch.cuda.set_device(hidden_states.device)
|
||||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||||
if layer_past is not None:
|
if layer_past is not None:
|
||||||
layer_past = layer_past.to(hidden_states.device)
|
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||||
# Ensure that attention_mask is always on the same device as hidden_states
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(hidden_states.device)
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
@@ -717,19 +717,25 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warn(
|
||||||
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# checkpointing only works with tuple returns, not with lists
|
# None for past_key_value
|
||||||
return tuple(output for output in module(*inputs, use_cache, output_attentions))
|
return module(*inputs, use_cache, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
outputs = torch.utils.checkpoint.checkpoint(
|
outputs = torch.utils.checkpoint.checkpoint(
|
||||||
create_custom_forward(block),
|
create_custom_forward(block),
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past,
|
None,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
@@ -932,6 +938,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
cross_attentions=transformer_outputs.cross_attentions,
|
cross_attentions=transformer_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
This function is used to re-order the :obj:`past_key_values` cache if
|
||||||
|
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||||
|
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return tuple(
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||||
|
for layer_past in past
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
@@ -1095,6 +1113,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
|
|||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
This function is used to re-order the :obj:`past_key_values` cache if
|
||||||
|
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
|
||||||
|
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return tuple(
|
||||||
|
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||||
|
for layer_past in past
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -466,7 +466,14 @@ class LayoutLMEncoder(nn.Module):
|
|||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
|
||||||
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warn(
|
||||||
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -1695,7 +1695,7 @@ class LEDEncoder(LEDPreTrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None, None)
|
layer_outputs = (None, None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -1920,11 +1920,13 @@ class LEDDecoder(LEDPreTrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -1226,7 +1226,7 @@ class LongformerEncoder(nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -742,7 +742,7 @@ class MarianEncoder(MarianPreTrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -958,11 +958,13 @@ class MarianDecoder(MarianPreTrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -780,7 +780,7 @@ class MBartEncoder(MBartPreTrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -1002,11 +1002,13 @@ class MBartDecoder(MBartPreTrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -746,7 +746,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -966,11 +966,13 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn(
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
)
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -479,7 +479,14 @@ class RobertaEncoder(nn.Module):
|
|||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
|
||||||
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
logger.warn(
|
||||||
|
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -1137,6 +1137,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
self.crit.cutoff_ends = [0] + new_cutoffs
|
self.crit.cutoff_ends = [0] + new_cutoffs
|
||||||
self.crit.n_token = new_num_tokens
|
self.crit.n_token = new_num_tokens
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
|
||||||
|
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
|
||||||
|
correct beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1462,6 +1462,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
attentions=transformer_outputs.attentions,
|
attentions=transformer_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
|
||||||
|
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
|
||||||
|
correct beam_idx at every generation step.
|
||||||
|
"""
|
||||||
|
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
|
|||||||
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
if self.training and (dropout_probability < self.layerdrop): # skip the layer
|
||||||
layer_outputs = (None, None)
|
layer_outputs = (None, None)
|
||||||
else:
|
else:
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
@@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
|
|||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||||
|
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
raise ValueError(
|
logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
|
||||||
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
|
use_cache = False
|
||||||
)
|
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
|
|||||||
@@ -131,6 +131,7 @@ class GPT2ModelTester:
|
|||||||
n_ctx=self.max_position_embeddings,
|
n_ctx=self.max_position_embeddings,
|
||||||
# type_vocab_size=self.type_vocab_size,
|
# type_vocab_size=self.type_vocab_size,
|
||||||
# initializer_range=self.initializer_range,
|
# initializer_range=self.initializer_range,
|
||||||
|
use_cache=not gradient_checkpointing,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
eos_token_id=self.eos_token_id,
|
eos_token_id=self.eos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user