GPT2Model StaticCache support (#35761)
* initial GPT2 changes * causal_mask support * return_legacy_cache * cleanup * fix1 * outputs shape fixes * gpt2 return fix * pkv, attn fixes * fix dual_head * is_causal arg fix * decision transformer updated * style fix * batch_size from inputs_embeds * DecisionTransformerModel fixes * cross-attn support + cache warning * x-attn @decision * EDCache proper init * simplified logic in `if use_cache:` for GPT2Model * @deprecate_kwarg for DecisionTr attn fwd * @deprecate_kwarg in gpt2 * deprecation version updated to 4.51 * kwargs in gradient_checkpointing_fn * rename next_cache to past_key_values * attention_mask prep * +cache_position in GPT2DoubleHeadsModel * undo kwargs in gradient checkpointing * moved up `if self.gradient_checkpointing` * consistency in decision_transformer * pastkv, cache_pos in grad_checkpt args * rm _reorder_cache * output_attentions streamlined * decision_transformer consistency * return_legacy_cache improved * ClvpForCausalLM used for legacy cache test now * is_causal fixed * attn_output cleanup * consistency @ decision_transformer * Updated deprecation notice version to 4.52 * upd deprecation * consistent legacy cache code in decision transformers\ * next_cache -> past_kv in decision_tr * cache support flags in decision_transf * rm legacy cache warning * consistency in cache init for decision transf * no Static Cache for Decision Transformer --------- Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -1589,7 +1589,6 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
|
||||
def _reorder_cache(
|
||||
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||
) -> Tuple[Tuple[torch.Tensor]]:
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||
@@ -34,6 +35,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from .configuration_decision_transformer import DecisionTransformerConfig
|
||||
|
||||
|
||||
@@ -125,7 +127,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@@ -257,19 +260,21 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||
if encoder_hidden_states is not None:
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
if is_cross_attention:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||
@@ -289,17 +294,17 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
key_states = torch.cat((past_key, key_states), dim=-2)
|
||||
value_states = torch.cat((past_value, value_states), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
if is_cross_attention:
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
present = None
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
|
||||
)
|
||||
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
||||
|
||||
using_eager = self.config._attn_implementation == "eager"
|
||||
@@ -338,11 +343,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs # a, present, (attentions)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
|
||||
@@ -383,10 +384,12 @@ class DecisionTransformerGPT2Block(nn.Module):
|
||||
|
||||
self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
|
||||
|
||||
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
@@ -396,16 +399,15 @@ class DecisionTransformerGPT2Block(nn.Module):
|
||||
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_outputs = self.attn(
|
||||
attn_output, self_attn_weights = self.attn(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
past_key_value=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
@@ -418,18 +420,17 @@ class DecisionTransformerGPT2Block(nn.Module):
|
||||
)
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_cross_attn(hidden_states)
|
||||
cross_attn_outputs = self.crossattention(
|
||||
cross_attn_output, cross_attn_weights = self.crossattention(
|
||||
hidden_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = cross_attn_outputs[0]
|
||||
# residual connection
|
||||
hidden_states = residual + attn_output
|
||||
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
||||
hidden_states = residual + cross_attn_output
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
@@ -437,12 +438,13 @@ class DecisionTransformerGPT2Block(nn.Module):
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attn_weights,)
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
return outputs
|
||||
|
||||
|
||||
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
||||
@@ -456,6 +458,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "transformer"
|
||||
is_parallelizable = True
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = False
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@@ -521,6 +525,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -558,14 +563,31 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
|
||||
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder and similar addition in GPT2Model
|
||||
return_legacy_cache = False
|
||||
if use_cache:
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
return_legacy_cache = True
|
||||
past_key_values = DynamicCache()
|
||||
elif not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
|
||||
"You should pass an instance of `Cache` instead, e.g. "
|
||||
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
|
||||
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
|
||||
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
# Attention mask.
|
||||
if attention_mask is not None:
|
||||
@@ -624,17 +646,13 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
for i, block in enumerate(self.h):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
@@ -648,6 +666,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
None,
|
||||
attention_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
@@ -658,7 +677,8 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -668,13 +688,11 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
all_self_attentions = all_self_attentions + (outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
all_cross_attentions = all_cross_attentions + (outputs[2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
@@ -689,16 +707,23 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
past_key_values = past_key_values if use_cache else None
|
||||
if return_legacy_cache:
|
||||
past_key_values = (
|
||||
past_key_values.self_attention_cache.to_legacy_cache()
|
||||
if self.config.add_cross_attention
|
||||
else past_key_values.to_legacy_cache()
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
|
||||
@@ -27,8 +27,9 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, get_activation
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@@ -46,6 +47,7 @@ from ...utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||
from .configuration_gpt2 import GPT2Config
|
||||
|
||||
@@ -136,7 +138,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
attn_weights = attn_weights + attention_mask
|
||||
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
@@ -267,19 +270,21 @@ class GPT2Attention(nn.Module):
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||
if encoder_hidden_states is not None:
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
if is_cross_attention:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||
@@ -299,17 +304,17 @@ class GPT2Attention(nn.Module):
|
||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
key_states = torch.cat((past_key, key_states), dim=-2)
|
||||
value_states = torch.cat((past_value, value_states), dim=-2)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_states, value_states)
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
if is_cross_attention:
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
present = None
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
|
||||
)
|
||||
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
||||
|
||||
using_eager = self.config._attn_implementation == "eager"
|
||||
@@ -348,11 +353,7 @@ class GPT2Attention(nn.Module):
|
||||
attn_output = self.c_proj(attn_output)
|
||||
attn_output = self.resid_dropout(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs # a, present, (attentions)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class GPT2MLP(nn.Module):
|
||||
@@ -388,10 +389,12 @@ class GPT2Block(nn.Module):
|
||||
|
||||
self.mlp = GPT2MLP(inner_dim, config)
|
||||
|
||||
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
@@ -401,16 +404,15 @@ class GPT2Block(nn.Module):
|
||||
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_1(hidden_states)
|
||||
attn_outputs = self.attn(
|
||||
attn_output, self_attn_weights = self.attn(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
past_key_value=past_key_value,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + residual
|
||||
|
||||
@@ -423,18 +425,17 @@ class GPT2Block(nn.Module):
|
||||
)
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_cross_attn(hidden_states)
|
||||
cross_attn_outputs = self.crossattention(
|
||||
cross_attn_output, cross_attn_weights = self.crossattention(
|
||||
hidden_states,
|
||||
past_key_value=past_key_value,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = cross_attn_outputs[0]
|
||||
# residual connection
|
||||
hidden_states = residual + attn_output
|
||||
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
||||
hidden_states = residual + cross_attn_output
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
@@ -442,12 +443,13 @@ class GPT2Block(nn.Module):
|
||||
# residual connection
|
||||
hidden_states = residual + feed_forward_hidden_states
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states,) + outputs
|
||||
else:
|
||||
outputs = (hidden_states,) + outputs[1:]
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
if encoder_hidden_states is not None:
|
||||
outputs += (cross_attn_weights,)
|
||||
|
||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
|
||||
@@ -565,6 +567,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
_supports_static_cache = True
|
||||
|
||||
def __init__(self, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
@@ -669,10 +673,24 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
||||
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
||||
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
||||
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||
|
||||
Two formats are allowed:
|
||||
- a [`~cache_utils.Cache`] instance, see our
|
||||
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
||||
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||
cache format.
|
||||
|
||||
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||
legacy cache format will be returned.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||
of shape `(batch_size, sequence_length)`.
|
||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
@@ -721,6 +739,10 @@ GPT2_INPUTS_DOCSTRING = r"""
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||
the complete sequence length.
|
||||
"""
|
||||
PARALLELIZE_DOCSTRING = r"""
|
||||
This is an experimental feature and is a subject to change at a moment's notice.
|
||||
@@ -868,7 +890,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
past_key_values: Optional[Union[Tuple[Tuple[torch.Tensor]], Cache]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -906,51 +929,56 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
|
||||
return_legacy_cache = False
|
||||
if use_cache:
|
||||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
return_legacy_cache = True
|
||||
past_key_values = DynamicCache()
|
||||
elif not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
logger.warning_once(
|
||||
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
|
||||
"You should pass an instance of `Cache` instead, e.g. "
|
||||
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
||||
)
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
|
||||
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
|
||||
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
)
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
position_embeds = self.wpe(position_ids)
|
||||
hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
|
||||
|
||||
# Attention mask.
|
||||
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
||||
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
|
||||
if self._attn_implementation == "flash_attention_2":
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif _use_sdpa:
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask=attention_mask,
|
||||
input_shape=(batch_size, input_shape[-1]),
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_length,
|
||||
# ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
|
||||
if attention_mask is not None and attention_mask.ndim < 4:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
)
|
||||
else:
|
||||
if attention_mask is not None:
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
@@ -979,25 +1007,13 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
|
||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
for i in range(len(self.h)):
|
||||
block, layer_past = self.h[i], past_key_values[i]
|
||||
for i, block in enumerate(self.h):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||
if layer_past is not None:
|
||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
@@ -1010,8 +1026,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
attention_mask,
|
||||
past_key_values,
|
||||
cache_position,
|
||||
causal_mask,
|
||||
head_mask[i],
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
@@ -1021,8 +1038,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
else:
|
||||
outputs = block(
|
||||
hidden_states,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=causal_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
@@ -1031,13 +1049,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
all_self_attentions = all_self_attentions + (outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||
all_cross_attentions = all_cross_attentions + (outputs[2],)
|
||||
|
||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||
if self.model_parallel:
|
||||
@@ -1052,21 +1068,149 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
past_key_values = past_key_values if use_cache else None
|
||||
if return_legacy_cache:
|
||||
past_key_values = (
|
||||
past_key_values.self_attention_cache.to_legacy_cache()
|
||||
if self.config.add_cross_attention
|
||||
else past_key_values.to_legacy_cache()
|
||||
)
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||
if v is not None
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
past_key_values=past_key_values,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype, device = input_tensor.dtype, input_tensor.device
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
device (`torch.device`):
|
||||
The device to plcae the 4D attention mask on.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
@@ -1137,6 +1281,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -1163,6 +1308,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
@@ -1206,20 +1352,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
|
||||
cross_attentions=transformer_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(
|
||||
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||
) -> Tuple[Tuple[torch.Tensor]]:
|
||||
"""
|
||||
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
||||
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
||||
beam_idx at every generation step.
|
||||
"""
|
||||
return tuple(
|
||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
||||
for layer_past in past_key_values
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
@@ -1292,6 +1424,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -1350,6 +1483,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
|
||||
@@ -40,9 +40,9 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
ClvpForCausalLM,
|
||||
DynamicCache,
|
||||
GenerationConfig,
|
||||
GPT2LMHeadModel,
|
||||
LlamaConfig,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
|
||||
|
||||
def test_reorder_cache_retrocompatibility(self):
|
||||
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
||||
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
|
||||
Reference in New Issue
Block a user