Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
092f1fdaa4 | ||
|
|
bf5163f413 | ||
|
|
6c45f0f631 | ||
|
|
bfefb8ef8b | ||
|
|
20164cc2c6 | ||
|
|
d5ec19460f | ||
|
|
bf99e862ff | ||
|
|
6d023503e9 | ||
|
|
4f8689e8ea | ||
|
|
a0857740c0 | ||
|
|
2f54e0b358 |
2
setup.py
2
setup.py
@@ -428,7 +428,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.38.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.38.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
|
||||
author_email="transformers@huggingface.co",
|
||||
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.38.0"
|
||||
__version__ = "4.38.2"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -2514,9 +2514,10 @@ class ConditionalDetrLoss(nn.Module):
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
|
||||
@@ -2282,9 +2282,10 @@ class DeformableDetrLoss(nn.Module):
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
|
||||
@@ -2345,9 +2345,10 @@ class DetaLoss(nn.Module):
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
# Check that we have initialized the distributed state
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
|
||||
@@ -2210,9 +2210,10 @@ class DetrLoss(nn.Module):
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
|
||||
@@ -101,18 +101,25 @@ class GemmaRotaryEmbedding(nn.Module):
|
||||
self.base = base
|
||||
self.register_buffer("inv_freq", None, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
if self.inv_freq is None:
|
||||
self.inv_freq = 1.0 / (
|
||||
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
|
||||
)
|
||||
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
@@ -124,7 +131,7 @@ def rotate_half(x):
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
@@ -132,9 +139,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
position_ids (`torch.Tensor`, *optional*):
|
||||
Deprecated and unused.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
@@ -277,7 +283,7 @@ class GemmaAttention(nn.Module):
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = attn_output.view(bsz, q_len, -1)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
@@ -811,8 +817,11 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
|
||||
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
|
||||
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
|
||||
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
|
||||
causal_mask = torch.full(
|
||||
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@@ -955,31 +964,28 @@ class GemmaModel(GemmaPreTrainedModel):
|
||||
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
|
||||
causal_mask = (
|
||||
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
|
||||
)
|
||||
else:
|
||||
mask = torch.full(
|
||||
(self.config.max_position_embeddings, self.config.max_position_embeddings),
|
||||
fill_value=torch.finfo(dtype).min,
|
||||
)
|
||||
causal_mask = torch.triu(mask, diagonal=1)
|
||||
# We use the current dtype to avoid any overflows
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
|
||||
|
||||
causal_mask = causal_mask.to(dtype=dtype, device=device)
|
||||
if attention_mask is not None and attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
||||
padding_mask, torch.finfo(dtype).min
|
||||
)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
|
||||
if self.config._attn_implementation == "sdpa":
|
||||
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
|
||||
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
|
||||
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
|
||||
dtype
|
||||
)
|
||||
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
|
||||
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
||||
is_tracing = (
|
||||
torch.jit.is_tracing()
|
||||
or isinstance(input_tensor, torch.fx.Proxy)
|
||||
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
||||
)
|
||||
if not is_tracing and torch.any(attention_mask != 1):
|
||||
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@@ -1079,7 +1085,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
logits = logits.float()
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
@@ -1146,29 +1152,32 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
|
||||
if self.generation_config.cache_implementation == "static":
|
||||
# generation with static cache
|
||||
past_length = past_key_value.get_seq_length()
|
||||
cache_position = kwargs.get("cache_position", None)
|
||||
if cache_position is None:
|
||||
past_length = 0
|
||||
else:
|
||||
past_length = cache_position[-1] + 1
|
||||
input_ids = input_ids[:, past_length:]
|
||||
position_ids = position_ids[:, past_length:]
|
||||
|
||||
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
|
||||
# same goes for position ids. Could also help with continued generation.
|
||||
cache_position = kwargs.get("cache_position", None)
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_length, past_length + position_ids.shape[-1], device=position_ids.device
|
||||
)
|
||||
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"position_ids": position_ids.contiguous(),
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
|
||||
@@ -563,10 +563,11 @@ class GPTNeoXRotaryEmbedding(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
|
||||
# TODO @gante bring compatibility back
|
||||
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
@@ -586,7 +587,8 @@ class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
|
||||
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
|
||||
# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding.__init__
|
||||
# TODO @gante no longer copied from
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
@@ -92,54 +92,63 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
|
||||
|
||||
|
||||
class LlamaRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
super().__init__()
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
# For BC we register cos and sin cached
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
|
||||
t = t / self.scaling_factor
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
|
||||
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
|
||||
|
||||
@property
|
||||
def sin_cached(self):
|
||||
logger.warning_once(
|
||||
"The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead."
|
||||
"The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
|
||||
)
|
||||
return self._sin_cached
|
||||
|
||||
@property
|
||||
def cos_cached(self):
|
||||
logger.warning_once(
|
||||
"The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead."
|
||||
"The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
|
||||
"the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
|
||||
)
|
||||
return self._cos_cached
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
if seq_len is not None:
|
||||
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
|
||||
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
|
||||
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype=x.dtype)
|
||||
sin = emb.sin().to(dtype=x.dtype)
|
||||
# backwards compatibility
|
||||
self._cos_cached = cos
|
||||
self._sin_cached = sin
|
||||
return cos, sin
|
||||
# Force float32 since bfloat16 loses precision on long contexts
|
||||
# See https://github.com/huggingface/transformers/pull/29285
|
||||
device_type = x.device.type
|
||||
device_type = device_type if isinstance(device_type, str) else "cpu"
|
||||
with torch.autocast(device_type=device_type, enabled=False):
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos()
|
||||
sin = emb.sin()
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# difference to the original RoPE: a scaling factor is aplied to the position ids
|
||||
position_ids = position_ids.float() / self.scaling_factor
|
||||
@@ -150,10 +159,6 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
|
||||
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
self.scaling_factor = scaling_factor
|
||||
super().__init__(dim, max_position_embeddings, base, device)
|
||||
|
||||
def forward(self, x, position_ids, seq_len=None):
|
||||
# difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
|
||||
seq_len = torch.max(position_ids) + 1
|
||||
@@ -367,6 +372,7 @@ class LlamaAttention(nn.Module):
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None: # no matter the length, we just slice it
|
||||
causal_mask = attention_mask
|
||||
if cache_position is not None:
|
||||
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
|
||||
attn_weights = attn_weights + causal_mask
|
||||
@@ -810,7 +816,9 @@ class LlamaPreTrainedModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
|
||||
causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
|
||||
causal_mask = torch.full(
|
||||
(max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
for layer in self.model.layers:
|
||||
@@ -918,8 +926,11 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
|
||||
causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
|
||||
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
|
||||
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
|
||||
causal_mask = torch.full(
|
||||
(config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
|
||||
)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
@@ -1058,31 +1069,28 @@ class LlamaModel(LlamaPreTrainedModel):
|
||||
causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
|
||||
self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
|
||||
|
||||
if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows
|
||||
causal_mask = (
|
||||
self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min
|
||||
)
|
||||
else:
|
||||
mask = torch.full(
|
||||
(self.config.max_position_embeddings, self.config.max_position_embeddings),
|
||||
fill_value=torch.finfo(dtype).min,
|
||||
)
|
||||
causal_mask = torch.triu(mask, diagonal=1)
|
||||
# We use the current dtype to avoid any overflows
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
|
||||
|
||||
causal_mask = causal_mask.to(dtype=dtype, device=device)
|
||||
if attention_mask is not None and attention_mask.dim() == 2:
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
|
||||
padding_mask, torch.finfo(dtype).min
|
||||
)
|
||||
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
|
||||
|
||||
if self.config._attn_implementation == "sdpa":
|
||||
is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
|
||||
if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
|
||||
causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
|
||||
dtype
|
||||
)
|
||||
if self.config._attn_implementation == "sdpa" and attention_mask is not None:
|
||||
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
|
||||
is_tracing = (
|
||||
torch.jit.is_tracing()
|
||||
or isinstance(input_tensor, torch.fx.Proxy)
|
||||
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
|
||||
)
|
||||
if not is_tracing and torch.any(attention_mask != 1):
|
||||
# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@@ -1253,29 +1261,32 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
|
||||
if self.generation_config.cache_implementation == "static":
|
||||
# generation with static cache
|
||||
past_length = past_key_value.get_seq_length()
|
||||
cache_position = kwargs.get("cache_position", None)
|
||||
if cache_position is None:
|
||||
past_length = 0
|
||||
else:
|
||||
past_length = cache_position[-1] + 1
|
||||
input_ids = input_ids[:, past_length:]
|
||||
position_ids = position_ids[:, past_length:]
|
||||
|
||||
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
|
||||
# same goes for position ids. Could also help with continued generation.
|
||||
cache_position = kwargs.get("cache_position", None)
|
||||
if cache_position is None:
|
||||
cache_position = torch.arange(
|
||||
past_length, past_length + position_ids.shape[-1], device=position_ids.device
|
||||
)
|
||||
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
||||
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
|
||||
# TODO: use `next_tokens` directly instead.
|
||||
model_inputs = {"input_ids": input_ids.contiguous()}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"position_ids": position_ids.contiguous(),
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
|
||||
@@ -243,7 +243,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
return vocab
|
||||
|
||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
|
||||
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
|
||||
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
|
||||
"""
|
||||
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
|
||||
first token is special.
|
||||
@@ -255,7 +255,7 @@ class LlamaTokenizer(PreTrainedTokenizer):
|
||||
if self.add_prefix_space:
|
||||
text = SPIECE_UNDERLINE + text
|
||||
|
||||
tokens = super().tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
|
||||
tokens = super().tokenize(text, **kwargs)
|
||||
|
||||
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
||||
tokens = tokens[1:]
|
||||
|
||||
@@ -791,14 +791,15 @@ class Mask2FormerLoss(nn.Module):
|
||||
Computes the average number of target masks across the batch, for normalization purposes.
|
||||
"""
|
||||
num_masks = sum([len(classes) for classes in class_labels])
|
||||
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
|
||||
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_masks_pt = reduce(num_masks_pt)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_masks = reduce(num_masks)
|
||||
world_size = PartialState().num_processes
|
||||
|
||||
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
|
||||
return num_masks_pt
|
||||
num_masks = torch.clamp(num_masks / world_size, min=1)
|
||||
return num_masks
|
||||
|
||||
|
||||
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
|
||||
|
||||
@@ -1198,14 +1198,15 @@ class MaskFormerLoss(nn.Module):
|
||||
Computes the average number of target masks across the batch, for normalization purposes.
|
||||
"""
|
||||
num_masks = sum([len(classes) for classes in class_labels])
|
||||
num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device)
|
||||
num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_masks_pt = reduce(num_masks_pt)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_masks = reduce(num_masks)
|
||||
world_size = PartialState().num_processes
|
||||
|
||||
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
|
||||
return num_masks_pt
|
||||
num_masks = torch.clamp(num_masks / world_size, min=1)
|
||||
return num_masks
|
||||
|
||||
|
||||
class MaskFormerFPNConvLayer(nn.Module):
|
||||
|
||||
@@ -727,14 +727,15 @@ class OneFormerLoss(nn.Module):
|
||||
Computes the average number of target masks across the batch, for normalization purposes.
|
||||
"""
|
||||
num_masks = sum([len(classes) for classes in class_labels])
|
||||
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
|
||||
num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_masks_pt = reduce(num_masks_pt)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_masks = reduce(num_masks)
|
||||
world_size = PartialState().num_processes
|
||||
|
||||
num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)
|
||||
return num_masks_pt
|
||||
num_masks = torch.clamp(num_masks / world_size, min=1)
|
||||
return num_masks
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -447,7 +447,7 @@ class SeamlessM4TTokenizer(PreTrainedTokenizer):
|
||||
return tokenizer
|
||||
|
||||
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
|
||||
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
|
||||
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
|
||||
"""
|
||||
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
|
||||
first token is special.
|
||||
@@ -459,7 +459,7 @@ class SeamlessM4TTokenizer(PreTrainedTokenizer):
|
||||
if self.add_prefix_space:
|
||||
text = SPIECE_UNDERLINE + text
|
||||
|
||||
tokens = super().tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
|
||||
tokens = super().tokenize(text, **kwargs)
|
||||
|
||||
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
||||
tokens = tokens[1:]
|
||||
|
||||
@@ -377,7 +377,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(self.vocab_file)
|
||||
|
||||
def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
|
||||
def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
|
||||
"""
|
||||
Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
|
||||
first token is special.
|
||||
@@ -389,7 +389,7 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
if self.add_prefix_space:
|
||||
text = SPIECE_UNDERLINE + text
|
||||
|
||||
tokens = super().tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
|
||||
tokens = super().tokenize(text, **kwargs)
|
||||
|
||||
if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
|
||||
tokens = tokens[1:]
|
||||
|
||||
@@ -1757,9 +1757,10 @@ class TableTransformerLoss(nn.Module):
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
|
||||
@@ -1079,9 +1079,10 @@ class YolosLoss(nn.Module):
|
||||
num_boxes = sum(len(t["class_labels"]) for t in targets)
|
||||
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
|
||||
world_size = 1
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
if is_accelerate_available():
|
||||
if PartialState._shared_state != {}:
|
||||
num_boxes = reduce(num_boxes)
|
||||
world_size = PartialState().num_processes
|
||||
num_boxes = torch.clamp(num_boxes / world_size, min=1).item()
|
||||
|
||||
# Compute all the requested losses
|
||||
|
||||
@@ -26,6 +26,7 @@ from transformers.testing_utils import (
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@@ -460,6 +461,71 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
def test_flash_attn_2_inference_padding_right(self):
|
||||
self.skipTest("Gemma flash attention does not support right padding")
|
||||
|
||||
@require_torch_sdpa
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
def test_sdpa_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_sdpa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa"
|
||||
)
|
||||
model_sdpa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
dummy_input = dummy_input.to(torch_device)
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_sdpa = outputs_sdpa.hidden_states[-1]
|
||||
|
||||
# gemma sdpa needs a high tolerance
|
||||
assert torch.allclose(logits_sdpa, logits, atol=3e-3)
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@pytest.mark.flash_attn_test
|
||||
@slow
|
||||
def test_flash_attn_2_equivalence(self):
|
||||
for model_class in self.all_model_classes:
|
||||
if not model_class._supports_flash_attn_2:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_fa = model_class.from_pretrained(
|
||||
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model_fa.to(torch_device)
|
||||
|
||||
model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
dummy_input = dummy_input.to(torch_device)
|
||||
outputs = model(dummy_input, output_hidden_states=True)
|
||||
outputs_fa = model_fa(dummy_input, output_hidden_states=True)
|
||||
|
||||
logits = outputs.hidden_states[-1]
|
||||
logits_fa = outputs_fa.hidden_states[-1]
|
||||
|
||||
# gemma flash attention 2 needs a high tolerance
|
||||
assert torch.allclose(logits_fa, logits, atol=3e-3)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@@ -542,6 +608,69 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
def test_model_2b_eager(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I am looking for some information on the ",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="eager"
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_torch_sdpa
|
||||
def test_model_2b_sdpa(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Khichdi",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa"
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@pytest.mark.flash_attn_test
|
||||
@require_flash_attn
|
||||
def test_model_2b_flash_attn(self):
|
||||
model_id = "google/gemma-2b"
|
||||
EXPECTED_TEXTS = [
|
||||
"Hello I am doing a project on the 1990s and I need to know what the most popular music",
|
||||
"Hi today I am going to share with you a very easy and simple recipe of <strong><em>Kaju Kat",
|
||||
]
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
|
||||
)
|
||||
model.to(torch_device)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_bitsandbytes
|
||||
def test_model_2b_4bit(self):
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
@@ -20,8 +20,9 @@ import unittest
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import LlamaConfig, is_torch_available, set_seed
|
||||
from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@@ -595,6 +596,55 @@ class LlamaIntegrationTest(unittest.TestCase):
|
||||
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_compile_static_cache(self):
|
||||
NUM_TOKENS_TO_GENERATE = 40
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
]
|
||||
prompts = [
|
||||
"Simply put, the theory of relativity states that ",
|
||||
"My favorite all time favorite condiment is ketchup.",
|
||||
]
|
||||
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
|
||||
model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential")
|
||||
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
|
||||
|
||||
def decode_one_tokens(model, cur_token, input_pos, cache_position):
|
||||
logits = model(
|
||||
cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True
|
||||
)[0]
|
||||
new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
||||
return new_token
|
||||
|
||||
batch_size, seq_length = inputs["input_ids"].shape
|
||||
with torch.no_grad():
|
||||
model._setup_cache(StaticCache, 2, max_cache_len=4096)
|
||||
cache_position = torch.arange(seq_length, device=torch_device)
|
||||
generated_ids = torch.zeros(
|
||||
batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device
|
||||
)
|
||||
generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)
|
||||
|
||||
logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0]
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
|
||||
generated_ids[:, seq_length] = next_token[:, 0]
|
||||
|
||||
decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
|
||||
cache_position = torch.tensor([seq_length + 1], device=torch_device)
|
||||
for _ in range(1, NUM_TOKENS_TO_GENERATE):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
||||
with CaptureLogger(logging.get_logger(__name__)) as cl:
|
||||
next_token = decode_one_tokens(model, next_token.clone(), None, cache_position)
|
||||
self.assertNotIn("skipping cudagraphs due to", cl.out)
|
||||
generated_ids[:, cache_position] = next_token.int()
|
||||
cache_position += 1
|
||||
|
||||
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
|
||||
|
||||
|
||||
@require_torch
|
||||
class CodeLlamaIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user