More model refactoring! (#35359)
* cohere
* style
* phi3
* style
* small fix
* small fix
* phi3 longrope
* oups
* Update rope (only for phi3 still)
* Update test_modeling_rope_utils.py
* Update modeling_phi3.py
* fix
* fix copies
* style
* Fix copied from bad renaming
This commit is contained in:
@@ -311,10 +311,10 @@ class RopeTest(unittest.TestCase):
|
||||
self.assertEqual(config.rope_theta, 10000.0)
|
||||
self.assertFalse(hasattr(config, "partial_rotary_factor"))
|
||||
|
||||
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on `factor`
|
||||
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on the seq_len
|
||||
dim = config.hidden_size // config.num_attention_heads
|
||||
short_factor = [2.0] * (dim // 2) # scaling applied when factor == 1.0
|
||||
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when factor > 1.0
|
||||
short_factor = [2.0] * (dim // 2) # scaling applied when seq_len <= max_position_embeddings
|
||||
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when seq_len > max_position_embeddings
|
||||
|
||||
rope_fn = ROPE_INIT_FUNCTIONS["default"]
|
||||
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
@@ -353,26 +353,18 @@ class RopeTest(unittest.TestCase):
|
||||
# Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
|
||||
rope_config_validation(config)
|
||||
|
||||
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
|
||||
factor = 1.0
|
||||
# Check 2: seq_len == 0 -> short factor is applied to the default frequencies
|
||||
config.rope_scaling = {
|
||||
"rope_type": "longrope",
|
||||
"factor": factor,
|
||||
"factor": 1.0,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=0)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))
|
||||
|
||||
# Check 3: Factor > 1.0 -> long factor is applied to the default frequencies
|
||||
factor = 10.0
|
||||
config.rope_scaling = {
|
||||
"rope_type": "longrope",
|
||||
"factor": factor,
|
||||
"short_factor": short_factor,
|
||||
"long_factor": long_factor,
|
||||
}
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device)
|
||||
# Check 3: seq_len > max_position_embeddings -> long factor is applied to the default frequencies
|
||||
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=config.max_position_embeddings + 1)
|
||||
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
|
||||
|
||||
def test_llama3_rope_numerically(self):
|
||||
|
||||
Reference in New Issue
Block a user