RoPE: relaxed rope validation (#32182)

* relaxed rope check

* lets also accept rope_type=None, defaulting to the original implementation

* type and rope_type can coexist
This commit is contained in:
Joao Gante
2024-07-24 15:00:48 +01:00
committed by GitHub
parent 165116bc14
commit e0182f3bd7
3 changed files with 93 additions and 36 deletions

View File

@@ -526,6 +526,60 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
torch.testing.assert_close(old_cos_long, new_cos_long)
torch.testing.assert_close(old_sin_long, new_sin_long)
def test_model_loading_old_rope_configs(self):
def _reinitialize_config(base_config, new_kwargs):
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
# steps.
base_config_dict = base_config.to_dict()
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
return new_config
# from untouched config -> ✅
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
original_model = LlamaForCausalLM(base_config).to(torch_device)
original_model(**model_inputs)
# from a config with the expected rope configuration -> ✅
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
config = _reinitialize_config(
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
)
self.assertTrue(config.rope_scaling["type"] == "linear")
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("factor field", logs.output[0])
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
config = _reinitialize_config(
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
)
original_model = LlamaForCausalLM(config).to(torch_device)
original_model(**model_inputs)
self.assertEqual(len(logs.output), 1)
self.assertIn("Unrecognized keys", logs.output[0])
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
with self.assertRaises(KeyError):
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes