Compare commits

...

9 Commits

Author SHA1 Message Date
Arthur Zucker
47c29ccfaf Patch release v4.43.3
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-07-26 17:09:09 +02:00
João Nadkarni
54bc29c1ba don't log base model architecture in wandb if log model is false (#32143)
* don't log base model architecture in wandb is log model is false

* Update src/transformers/integrations/integration_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* convert log model setting into an enum

* fix formatting

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
2024-07-26 17:01:59 +02:00
Kashif Rasul
cc75146d0e [BigBird Pegasus] set _supports_param_buffer_assignment to False (#32222)
set _supports_param_buffer_assignment to False
2024-07-26 17:01:59 +02:00
Sanchit Gandhi
cd06184cc4 [whisper] fix short-form output type (#32178)
* [whisper] fix short-form output type

* add test

* make style

* update long-form tests

* fixes

* last fix

* finalise test
2024-07-26 17:01:59 +02:00
Lysandre
38d94bffa6 Patch release
Some checks failed
Release - Conda / build_and_package (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
2024-07-24 17:42:52 +02:00
Marc Sun
b4a0442dbd Fix float8_e4m3fn in modeling_utils (#32193)
* Fix float8_e4m3fn in modeling_utils

* style

* fix

* comment
2024-07-24 17:42:35 +02:00
Raushan Turganbay
4672b4d79b Fix resize embedding with Deepspeed (#32192)
fix resize when deepspeed
2024-07-24 17:42:28 +02:00
Arthur
a2b6a001c0 let's not warn when someone is running a forward (#32176)
* let's not warn when someone is running a foward without cache + self.training

* more models

* fixup
2024-07-24 17:42:19 +02:00
Joao Gante
64a90d72a8 RoPE: relaxed rope validation (#32182)
* relaxed rope check

* lets also accept rope_type=None, defaulting to the original implementation

* type and rope_type can coexist
2024-07-24 17:42:14 +02:00
26 changed files with 241 additions and 108 deletions

View File

@@ -430,7 +430,7 @@ install_requires = [
setup(
name="transformers",
version="4.43.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.43.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.43.1"
__version__ = "4.43.3"
from typing import TYPE_CHECKING

View File

@@ -26,6 +26,7 @@ import shutil
import sys
import tempfile
from dataclasses import asdict, fields
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
@@ -726,6 +727,35 @@ def save_model_architecture_to_file(model: Any, output_dir: str):
print(model, file=f)
class WandbLogModel(str, Enum):
"""Enum of possible log model values in W&B."""
CHECKPOINT = "checkpoint"
END = "end"
FALSE = "false"
@property
def is_enabled(self) -> bool:
"""Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
@classmethod
def _missing_(cls, value: Any) -> "WandbLogModel":
if not isinstance(value, str):
raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
if value.upper() in ENV_VARS_TRUE_VALUES:
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
return WandbLogModel.END
logger.warning(
f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
)
return WandbLogModel.FALSE
class WandbCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@@ -740,16 +770,7 @@ class WandbCallback(TrainerCallback):
self._wandb = wandb
self._initialized = False
# log model
if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
DeprecationWarning(
f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
"version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
)
logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
self._log_model = "end"
else:
self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
def setup(self, args, state, model, **kwargs):
"""
@@ -834,37 +855,38 @@ class WandbCallback(TrainerCallback):
logger.info("Could not log the number of model parameters in Weights & Biases.")
# log the initial model architecture to an artifact
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
if self._log_model.is_enabled:
with tempfile.TemporaryDirectory() as temp_dir:
model_name = (
f"model-{self._wandb.run.id}"
if (args.run_name is None or args.run_name == args.output_dir)
else f"model-{self._wandb.run.name}"
)
model_artifact = self._wandb.Artifact(
name=model_name,
type="model",
metadata={
"model_config": model.config.to_dict() if hasattr(model, "config") else None,
"num_parameters": self._wandb.config.get("model/num_parameters"),
"initial_model": True,
},
)
# add the architecture to a separate text file
save_model_architecture_to_file(model, temp_dir)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
for f in Path(temp_dir).glob("*"):
if f.is_file():
with model_artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)
badge_markdown = (
f'[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge'
f'-28.svg" alt="Visualize in Weights & Biases" width="20'
f'0" height="32"/>]({self._wandb.run.get_url()})'
)
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
@@ -880,7 +902,7 @@ class WandbCallback(TrainerCallback):
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
from ..trainer import Trainer
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -938,7 +960,7 @@ class WandbCallback(TrainerCallback):
self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
def on_save(self, args, state, control, **kwargs):
if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()

View File

@@ -354,6 +354,11 @@ ROPE_INIT_FUNCTIONS = {
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's gracefully handle it
if "rope_type" not in received_keys and "type" in received_keys:
received_keys -= {"type"}
received_keys.add("rope_type")
missing_keys = required_keys - received_keys
if missing_keys:
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
@@ -361,14 +366,14 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
if optional_keys is not None:
unused_keys = received_keys - required_keys - optional_keys
else:
unused_keys = received_keys - received_keys
unused_keys = received_keys - required_keys
if unused_keys:
raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
def _validate_default_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
@@ -376,19 +381,19 @@ def _validate_default_rope_parameters(config: PretrainedConfig):
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"original_max_position_embeddings"}
@@ -397,12 +402,12 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
def _validate_yarn_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor"}
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
received_keys = set(rope_scaling.keys())
@@ -410,22 +415,22 @@ def _validate_yarn_parameters(config: PretrainedConfig):
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
raise ValueError(
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
beta_fast = rope_scaling.get("beta_fast")
if beta_fast is not None and not isinstance(beta_fast, float):
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
beta_slow = rope_scaling.get("beta_slow")
if beta_slow is not None and not isinstance(beta_slow, float):
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
if (beta_fast or 32) < (beta_slow or 1):
raise ValueError(
logger.warning(
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)
@@ -433,7 +438,7 @@ def _validate_yarn_parameters(config: PretrainedConfig):
def _validate_longrope_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "short_factor", "long_factor"}
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
@@ -445,15 +450,15 @@ def _validate_longrope_parameters(config: PretrainedConfig):
short_factor = rope_scaling.get("short_factor")
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
if not len(short_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
long_factor = rope_scaling.get("long_factor")
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
if not len(long_factor) == dim // 2:
raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
@@ -468,48 +473,48 @@ def _validate_longrope_parameters(config: PretrainedConfig):
else:
factor = rope_scaling.get("factor")
if factor is None:
raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
elif not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
raise ValueError(
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
def _validate_llama3_parameters(config: PretrainedConfig):
rope_scaling = config.rope_scaling
rope_type = rope_scaling["rope_type"]
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys)
factor = rope_scaling["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
if low_freq_factor is None or not isinstance(low_freq_factor, float):
raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
if high_freq_factor is None or not isinstance(high_freq_factor, float):
raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
if high_freq_factor < low_freq_factor:
raise ValueError(
logger.warning(
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
)
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
raise ValueError(
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
f"{original_max_position_embeddings}"
)
if original_max_position_embeddings >= config.max_position_embeddings:
raise ValueError(
logger.warning(
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
)
@@ -534,17 +539,12 @@ def rope_config_validation(config: PretrainedConfig):
if rope_scaling is None:
return
possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
if rope_type is None:
raise ValueError(
f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
)
# BC: "rope_type" was originally "type"
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
if validation_fn is not None:
validation_fn(config)
else:
raise ValueError(
logger.warning(
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)

View File

@@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
@@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
module_name = param_name
set_module_kwargs = {}
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
@@ -2131,13 +2134,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Replace weights in old_embeddings and return to maintain the same embedding type.
# This ensures correct functionality when a Custom Embedding class is passed as input.
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
# will be set to `None` in the resized embeddings.
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
# If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
# will be set to `None` in the resized embeddings.
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
else:
old_embeddings.weight.data = new_embeddings.weight.data
old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
old_embeddings.padding_idx = None
return old_embeddings

View File

@@ -1569,6 +1569,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_param_buffer_assignment = False
def _init_weights(self, module):
std = self.config.init_std

View File

@@ -769,7 +769,9 @@ class CohereModel(CoherePreTrainedModel):
past_seen_tokens = 0
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@@ -1004,7 +1004,9 @@ class DbrxModel(DbrxPreTrainedModel):
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -474,7 +474,9 @@ class GemmaModel(LlamaModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@@ -769,7 +769,9 @@ class GemmaModel(GemmaPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False # noqa: F841
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
@@ -794,7 +796,9 @@ class GemmaModel(GemmaPreTrainedModel):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -978,7 +978,9 @@ class JetMoeModel(JetMoePreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

View File

@@ -189,6 +189,9 @@ class LlamaConfig(PretrainedConfig):
self.mlp_bias = mlp_bias
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(

View File

@@ -893,7 +893,9 @@ class LlamaModel(LlamaPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -757,7 +757,7 @@ class MistralModel(MistralPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
return_legacy_cache = True
logger.warning_once(

View File

@@ -959,7 +959,7 @@ class MixtralModel(MixtralPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -810,7 +810,9 @@ class OlmoModel(OlmoPreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -626,7 +626,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -908,7 +908,7 @@ class PhiModel(PhiPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -949,7 +949,7 @@ class Phi3Model(Phi3PreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -807,7 +807,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -969,7 +969,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -901,7 +901,7 @@ class StableLmModel(StableLmPreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -783,7 +783,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel):
use_cache = False
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(

View File

@@ -498,7 +498,7 @@ class WhisperGenerationMixin:
# 3. Make sure generation config is correctly set
# Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
self._set_return_outputs(
return_dict_in_generate = self._set_return_outputs(
return_dict_in_generate=return_dict_in_generate,
return_token_timestamps=return_token_timestamps,
logprob_threshold=logprob_threshold,
@@ -732,7 +732,7 @@ class WhisperGenerationMixin:
else:
outputs = sequences
if generation_config.return_dict_in_generate:
if return_dict_in_generate and generation_config.return_dict_in_generate:
dict_outputs = self._stack_split_outputs(seek_outputs, model_output_type, sequences.device, kwargs)
if num_return_sequences > 1:
@@ -1109,18 +1109,20 @@ class WhisperGenerationMixin:
def _set_return_outputs(return_dict_in_generate, return_token_timestamps, logprob_threshold, generation_config):
if return_dict_in_generate is None:
return_dict_in_generate = generation_config.return_dict_in_generate
else:
generation_config.return_dict_in_generate = return_dict_in_generate
generation_config.return_token_timestamps = return_token_timestamps
if return_token_timestamps:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_attentions = True
generation_config.output_scores = True
if logprob_threshold is not None:
return_dict_in_generate = True
generation_config.return_dict_in_generate = True
generation_config.output_scores = True
generation_config.return_dict_in_generate = return_dict_in_generate
return return_dict_in_generate
def _set_return_timestamps(self, return_timestamps, is_shortform, generation_config):
if not is_shortform:

View File

@@ -526,6 +526,60 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)
def test_model_loading_old_rope_configs(self):
def _reinitialize_config(base_config, new_kwargs):
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
# steps.
base_config_dict = base_config.to_dict()
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
return new_config
# from untouched config -> ✅
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
original_model = LlamaForCausalLM(base_config).to(torch_device)
original_model(**model_inputs)
# from a config with the expected rope configuration -> ✅
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
config = _reinitialize_config(
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
)
self.assertTrue(config.rope_scaling["type"] == "linear")
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("factor field", logs.output[0])
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
)
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("Unrecognized keys", logs.output[0])
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes

View File

@@ -26,6 +26,7 @@ import unittest
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
from parameterized import parameterized
import transformers
from transformers import WhisperConfig
@@ -72,6 +73,7 @@ if is_torch_available():
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
@@ -1820,6 +1822,26 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@parameterized.expand([(True,), (False,)])
def test_generate_output_type(self, return_dict_in_generate):
expected_output_type = GenerateEncoderDecoderOutput if return_dict_in_generate else torch.Tensor
for model_class in self.all_generative_model_classes:
config, inputs = self.model_tester.prepare_config_and_inputs()
model = model_class(config).to(torch_device).eval()
# short-form generation without fallback
pred_ids = model.generate(**inputs, return_dict_in_generate=return_dict_in_generate)
assert isinstance(pred_ids, expected_output_type)
# short-form generation with fallback
pred_ids = model.generate(
**inputs,
logprob_threshold=-1.0,
temperature=[0.0, 0.1],
return_dict_in_generate=return_dict_in_generate,
)
assert isinstance(pred_ids, expected_output_type)
@require_torch
@require_torchaudio