Bugfix/alexsherstinsky/fix none check for attention factor in rope scaling 2024 08 28 0 (#33188)
* Fixing a bug in the way "attention_factor" is validated in ROPE utilities. * Fixing a bug in the way "attention_factor" is validated in ROPE utilities. * Fixing a bug in the way "attention_factor" is validated in ROPE utilities.
This commit is contained in:
@@ -487,10 +487,11 @@ def _validate_longrope_parameters(config: PretrainedConfig):
|
|||||||
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
|
||||||
|
|
||||||
attention_factor = rope_scaling.get("attention_factor")
|
attention_factor = rope_scaling.get("attention_factor")
|
||||||
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
|
if attention_factor is not None:
|
||||||
logger.warning(
|
if not isinstance(attention_factor, float) or attention_factor < 0.0:
|
||||||
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
logger.warning(
|
||||||
)
|
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_llama3_parameters(config: PretrainedConfig):
|
def _validate_llama3_parameters(config: PretrainedConfig):
|
||||||
|
|||||||
@@ -330,6 +330,16 @@ class RopeTest(unittest.TestCase):
|
|||||||
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
|
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
|
||||||
self.assertEqual(attention_scale, 0.5)
|
self.assertEqual(attention_scale, 0.5)
|
||||||
|
|
||||||
|
config.rope_scaling = {
|
||||||
|
"rope_type": "longrope",
|
||||||
|
"factor": factor,
|
||||||
|
"short_factor": short_factor,
|
||||||
|
"long_factor": long_factor,
|
||||||
|
}
|
||||||
|
self.assertEqual(config.rope_scaling.get("attention_factor"), None)
|
||||||
|
# Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
|
||||||
|
rope_config_validation(config)
|
||||||
|
|
||||||
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
|
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
|
||||||
factor = 1.0
|
factor = 1.0
|
||||||
config.rope_scaling = {
|
config.rope_scaling = {
|
||||||
|
|||||||
Reference in New Issue
Block a user