Bugfix/alexsherstinsky/fix none check for attention factor in rope scaling 2024 08 28 0 (#33188)

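* Fix how "attention_factor" is validated in the RoPE utilities: because `and` binds tighter than `or`, the old condition still evaluated `attention_factor < 0` when "attention_factor" was absent (None), raising "TypeError: '<' not supported between instances of 'NoneType' and 'int'". The None check now guards the whole validation.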

* Fixing a bug in the way "attention_factor" is validated in ROPE utilities.

* Fixing a bug in the way "attention_factor" is validated in ROPE utilities.
This commit is contained in:
Alex Sherstinsky
2024-09-04 08:01:12 -07:00
committed by GitHub
parent 178cb6bb1c
commit 122ded0a11
2 changed files with 15 additions and 4 deletions

View File

@@ -487,10 +487,11 @@ def _validate_longrope_parameters(config: PretrainedConfig):
logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
attention_factor = rope_scaling.get("attention_factor")
if attention_factor is not None and not isinstance(attention_factor, float) or attention_factor < 0:
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
if attention_factor is not None:
if not isinstance(attention_factor, float) or attention_factor < 0.0:
logger.warning(
f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
)
def _validate_llama3_parameters(config: PretrainedConfig):

View File

@@ -330,6 +330,16 @@ class RopeTest(unittest.TestCase):
_, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
self.assertEqual(attention_scale, 0.5)
config.rope_scaling = {
"rope_type": "longrope",
"factor": factor,
"short_factor": short_factor,
"long_factor": long_factor,
}
self.assertEqual(config.rope_scaling.get("attention_factor"), None)
# Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
rope_config_validation(config)
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
factor = 1.0
config.rope_scaling = {