RoPE: relaxed rope validation (#32182)
* relaxed rope check * lets also accept rope_type=None, defaulting to the original implementation * type and rope_type can coexist
This commit is contained in:
@@ -526,6 +526,60 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
torch.testing.assert_close(old_cos_long, new_cos_long)
|
||||
torch.testing.assert_close(old_sin_long, new_sin_long)
|
||||
|
||||
def test_model_loading_old_rope_configs(self):
|
||||
def _reinitialize_config(base_config, new_kwargs):
|
||||
# Reinitialize the config with the new kwargs, forcing the config to go through its __init__ validation
|
||||
# steps.
|
||||
base_config_dict = base_config.to_dict()
|
||||
new_config = LlamaConfig.from_dict(config_dict={**base_config_dict, **new_kwargs})
|
||||
return new_config
|
||||
|
||||
# from untouched config -> ✅
|
||||
base_config, model_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
original_model = LlamaForCausalLM(base_config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the expected rope configuration -> ✅
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0}})
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with the old rope configuration ('type' instead of 'rope_type') -> ✅ we gracefully handle BC
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"type": "linear", "factor": 10.0}})
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with both 'type' and 'rope_type' -> ✅ they can coexist (and both are present in the config)
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"type": "linear", "rope_type": "linear", "factor": 10.0}}
|
||||
)
|
||||
self.assertTrue(config.rope_scaling["type"] == "linear")
|
||||
self.assertTrue(config.rope_scaling["rope_type"] == "linear")
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
|
||||
# from a config with parameters in a bad range ('factor' should be >= 1.0) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear", "factor": -999.0}})
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("factor field", logs.output[0])
|
||||
|
||||
# from a config with unknown parameters ('foo' isn't a rope option) -> ⚠️ throws a warning
|
||||
with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
|
||||
config = _reinitialize_config(
|
||||
base_config, {"rope_scaling": {"rope_type": "linear", "factor": 10.0, "foo": "bar"}}
|
||||
)
|
||||
original_model = LlamaForCausalLM(config).to(torch_device)
|
||||
original_model(**model_inputs)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("Unrecognized keys", logs.output[0])
|
||||
|
||||
# from a config with specific rope type but missing one of its mandatory parameters -> ❌ throws exception
|
||||
with self.assertRaises(KeyError):
|
||||
config = _reinitialize_config(base_config, {"rope_scaling": {"rope_type": "linear"}}) # missing "factor"
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@require_bitsandbytes
|
||||
|
||||
Reference in New Issue
Block a user