Ignore keys on validate_rope (#33753)

* ignore keys on check rope

* add tests

* fix tests, so maybe better leave at logger lvl
This commit is contained in:
Raushan Turganbay
2024-10-04 12:39:37 +02:00
committed by GitHub
parent 4a173b88b5
commit 061c2c4c38
3 changed files with 42 additions and 17 deletions

View File

@@ -65,6 +65,19 @@ class RopeTest(unittest.TestCase):
with self.assertRaises(KeyError):
rope_config_validation(config)
# Any other parameters passed to RoPE will raise a warning that a particular key is not used
# But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys`
model_specific_kwarg = "mrope_sections" # e,g in Qwen2-VL
for rope_type in all_rope_types:
if rope_type == "default":
config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True}
rope_config_validation(config, ignore_keys={model_specific_kwarg})
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
rope_config_validation(config)
self.assertEqual(len(logs.output), 1)
self.assertIn(model_specific_kwarg, logs.output[0])
def test_default_rope_function_bc(self):
config = LlamaConfig()
device = torch_device