RoPE: relaxed rope validation (#32182)
* relaxed rope check * lets also accept rope_type=None, defaulting to the original implementation * type and rope_type can coexist
This commit is contained in:
@@ -354,6 +354,11 @@ ROPE_INIT_FUNCTIONS = {
|
||||
|
||||
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
||||
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
||||
# BC: "rope_type" was originally "type" -- let's gracefully handle it
|
||||
if "rope_type" not in received_keys and "type" in received_keys:
|
||||
received_keys -= {"type"}
|
||||
received_keys.add("rope_type")
|
||||
|
||||
missing_keys = required_keys - received_keys
|
||||
if missing_keys:
|
||||
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
||||
@@ -361,14 +366,14 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
|
||||
if optional_keys is not None:
|
||||
unused_keys = received_keys - required_keys - optional_keys
|
||||
else:
|
||||
unused_keys = received_keys - received_keys
|
||||
unused_keys = received_keys - required_keys
|
||||
if unused_keys:
|
||||
raise KeyError(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
||||
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
||||
|
||||
|
||||
def _validate_default_rope_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
@@ -376,19 +381,19 @@ def _validate_default_rope_parameters(config: PretrainedConfig):
|
||||
|
||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"original_max_position_embeddings"}
|
||||
@@ -397,12 +402,12 @@ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
|
||||
def _validate_yarn_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor"}
|
||||
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
@@ -410,22 +415,22 @@ def _validate_yarn_parameters(config: PretrainedConfig):
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
attention_factor = rope_scaling.get("attention_factor")
|
||||
if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||
)
|
||||
beta_fast = rope_scaling.get("beta_fast")
|
||||
if beta_fast is not None and not isinstance(beta_fast, float):
|
||||
raise ValueError(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
||||
logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
|
||||
beta_slow = rope_scaling.get("beta_slow")
|
||||
if beta_slow is not None and not isinstance(beta_slow, float):
|
||||
raise ValueError(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
||||
logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
|
||||
|
||||
if (beta_fast or 32) < (beta_slow or 1):
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
|
||||
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
|
||||
)
|
||||
@@ -433,7 +438,7 @@ def _validate_yarn_parameters(config: PretrainedConfig):
|
||||
|
||||
def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "short_factor", "long_factor"}
|
||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
||||
@@ -445,15 +450,15 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
|
||||
short_factor = rope_scaling.get("short_factor")
|
||||
if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor):
|
||||
raise ValueError(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
||||
logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
|
||||
if not len(short_factor) == dim // 2:
|
||||
raise ValueError(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
||||
logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
|
||||
|
||||
long_factor = rope_scaling.get("long_factor")
|
||||
if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor):
|
||||
raise ValueError(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
||||
logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
|
||||
if not len(long_factor) == dim // 2:
|
||||
raise ValueError(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
||||
logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
|
||||
|
||||
# Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
|
||||
# `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
|
||||
@@ -468,48 +473,48 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
||||
else:
|
||||
factor = rope_scaling.get("factor")
|
||||
if factor is None:
|
||||
raise ValueError("Missing required keys in `rope_scaling`: 'factor'")
|
||||
logger.warning("Missing required keys in `rope_scaling`: 'factor'")
|
||||
elif not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
attention_factor = rope_scaling.get("attention_factor")
|
||||
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
||||
rope_scaling = config.rope_scaling
|
||||
rope_type = rope_scaling["rope_type"]
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
||||
received_keys = set(rope_scaling.keys())
|
||||
_check_received_keys(rope_type, received_keys, required_keys)
|
||||
|
||||
factor = rope_scaling["factor"]
|
||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||
|
||||
low_freq_factor = rope_scaling["low_freq_factor"]
|
||||
high_freq_factor = rope_scaling["high_freq_factor"]
|
||||
if low_freq_factor is None or not isinstance(low_freq_factor, float):
|
||||
raise ValueError(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
||||
logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
|
||||
if high_freq_factor is None or not isinstance(high_freq_factor, float):
|
||||
raise ValueError(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
||||
logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
|
||||
if high_freq_factor < low_freq_factor:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
|
||||
f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
|
||||
)
|
||||
|
||||
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
|
||||
if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
|
||||
f"{original_max_position_embeddings}"
|
||||
)
|
||||
if original_max_position_embeddings >= config.max_position_embeddings:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
|
||||
f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
|
||||
)
|
||||
@@ -534,17 +539,12 @@ def rope_config_validation(config: PretrainedConfig):
|
||||
if rope_scaling is None:
|
||||
return
|
||||
|
||||
possible_rope_types = set(ROPE_INIT_FUNCTIONS.keys())
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||
if rope_type is None:
|
||||
raise ValueError(
|
||||
f"rope_scaling must contain a non-None 'rope_type' field. Possible options are {possible_rope_types}"
|
||||
)
|
||||
|
||||
# BC: "rope_type" was originally "type"
|
||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
||||
if validation_fn is not None:
|
||||
validation_fn(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
||||
)
|
||||
|
||||
@@ -189,6 +189,9 @@ class LlamaConfig(PretrainedConfig):
|
||||
self.mlp_bias = mlp_bias
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
|
||||
@@ -526,6 +526,60 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
torch.testing.assert_close(old_cos_long, new_cos_long)
|
||||
torch.testing.assert_close(old_sin_long, new_sin_long)
|
||||
|
||||
def test_model_loading_old_rope_configs(self):
|
||||
def _reinitialize_config(base_config, new_kwargs):
|
||||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
|
||||
# steps.
|
||||
base_config_dict = base_config.to_dict()
|
||||
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
|
||||
return new_config
|
||||
|
||||
# from untouched config -> ✅
|
||||
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_model = LlamaForCausalLM(base_config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the expected rope configuration -> ✅
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
|
||||
)
|
||||
self.assertTrue(config.rope_scaling["type"] == "linear")
|
||||
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("factor field", logs.output[0])
|
||||
|
||||
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
|
||||
)
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("Unrecognized keys", logs.output[0])
|
||||
|
||||
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
|
||||
with self.assertRaises(KeyError):
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
|
||||
Reference in New Issue
Block a user