Ignore keys on validate_rope (#33753)
* ignore keys on check rope * add tests * fix tests, so maybe better leave at logger lvl
This commit is contained in:
committed by
GitHub
parent
4a173b88b5
commit
061c2c4c38
@@ -360,13 +360,23 @@ ROPE_INIT_FUNCTIONS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _check_received_keys(rope_type: str, received_keys: set, required_keys: set, optional_keys: Optional[set] = None):
|
def _check_received_keys(
|
||||||
|
rope_type: str,
|
||||||
|
received_keys: set,
|
||||||
|
required_keys: set,
|
||||||
|
optional_keys: Optional[set] = None,
|
||||||
|
ignore_keys: Optional[set] = None,
|
||||||
|
):
|
||||||
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
|
||||||
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
|
||||||
if "type" in received_keys:
|
if "type" in received_keys:
|
||||||
received_keys -= {"type"}
|
received_keys -= {"type"}
|
||||||
required_keys.add("rope_type")
|
required_keys.add("rope_type")
|
||||||
|
|
||||||
|
# Some models need to store model-specific keys, and we don't want to throw warning at them
|
||||||
|
if ignore_keys is not None:
|
||||||
|
received_keys -= ignore_keys
|
||||||
|
|
||||||
missing_keys = required_keys - received_keys
|
missing_keys = required_keys - received_keys
|
||||||
if missing_keys:
|
if missing_keys:
|
||||||
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
|
||||||
@@ -379,47 +389,47 @@ def _check_received_keys(rope_type: str, received_keys: set, required_keys: set,
|
|||||||
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_default_rope_parameters(config: PretrainedConfig):
|
def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type"}
|
required_keys = {"rope_type"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
|
||||||
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig):
|
def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig):
|
def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||||
optional_keys = {"original_max_position_embeddings"}
|
optional_keys = {"original_max_position_embeddings"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||||
|
|
||||||
|
|
||||||
def _validate_yarn_parameters(config: PretrainedConfig):
|
def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor"}
|
required_keys = {"rope_type", "factor"}
|
||||||
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
@@ -444,14 +454,14 @@ def _validate_yarn_parameters(config: PretrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_longrope_parameters(config: PretrainedConfig):
|
def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "short_factor", "long_factor"}
|
required_keys = {"rope_type", "short_factor", "long_factor"}
|
||||||
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
# TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
|
||||||
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys, optional_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
|
||||||
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
@@ -494,12 +504,12 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
rope_scaling = config.rope_scaling
|
rope_scaling = config.rope_scaling
|
||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
|
||||||
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
|
||||||
received_keys = set(rope_scaling.keys())
|
received_keys = set(rope_scaling.keys())
|
||||||
_check_received_keys(rope_type, received_keys, required_keys)
|
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
factor = rope_scaling["factor"]
|
factor = rope_scaling["factor"]
|
||||||
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
if factor is None or not isinstance(factor, float) or factor < 1.0:
|
||||||
@@ -541,7 +551,7 @@ ROPE_VALIDATION_FUNCTIONS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def rope_config_validation(config: PretrainedConfig):
|
def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
|
||||||
"""
|
"""
|
||||||
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
Validate the RoPE config arguments, given a `PretrainedConfig` object
|
||||||
"""
|
"""
|
||||||
@@ -553,7 +563,7 @@ def rope_config_validation(config: PretrainedConfig):
|
|||||||
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
|
||||||
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
|
||||||
if validation_fn is not None:
|
if validation_fn is not None:
|
||||||
validation_fn(config)
|
validation_fn(config, ignore_keys=ignore_keys)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
|
||||||
|
|||||||
@@ -235,11 +235,13 @@ class Qwen2VLConfig(PretrainedConfig):
|
|||||||
|
|
||||||
# Validate the correctness of rotary position embeddings parameters
|
# Validate the correctness of rotary position embeddings parameters
|
||||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||||
# and change type from 'mrope' to 'default'
|
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
|
||||||
|
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
||||||
|
# TODO: @raushan update config in the hub
|
||||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||||
if self.rope_scaling["type"] == "mrope":
|
if self.rope_scaling["type"] == "mrope":
|
||||||
self.rope_scaling["type"] = "default"
|
self.rope_scaling["type"] = "default"
|
||||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||||
rope_config_validation(self)
|
rope_config_validation(self, ignore_keys={"mrope_section"})
|
||||||
|
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||||
|
|||||||
@@ -65,6 +65,19 @@ class RopeTest(unittest.TestCase):
|
|||||||
with self.assertRaises(KeyError):
|
with self.assertRaises(KeyError):
|
||||||
rope_config_validation(config)
|
rope_config_validation(config)
|
||||||
|
|
||||||
|
# Any other parameters passed to RoPE will raise a warning that a particular key is not used
|
||||||
|
# But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys`
|
||||||
|
model_specific_kwarg = "mrope_sections" # e,g in Qwen2-VL
|
||||||
|
|
||||||
|
for rope_type in all_rope_types:
|
||||||
|
if rope_type == "default":
|
||||||
|
config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True}
|
||||||
|
rope_config_validation(config, ignore_keys={model_specific_kwarg})
|
||||||
|
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||||
|
rope_config_validation(config)
|
||||||
|
self.assertEqual(len(logs.output), 1)
|
||||||
|
self.assertIn(model_specific_kwarg, logs.output[0])
|
||||||
|
|
||||||
def test_default_rope_function_bc(self):
|
def test_default_rope_function_bc(self):
|
||||||
config = LlamaConfig()
|
config = LlamaConfig()
|
||||||
device = torch_device
|
device = torch_device
|
||||||
|
|||||||
Reference in New Issue
Block a user