More model refactoring! (#35359)

* cohere

* style

* phi3

* style

* small fix

* small fix

* phi3 longrope

* oups

* Update rope (only for phi3 still)

* Update test_modeling_rope_utils.py

* Update modeling_phi3.py

* fix

* fix copies

* style

* Fix copied from bad renaming
This commit is contained in:
Cyril Vallez
2025-01-09 11:09:09 +01:00
committed by GitHub
parent 137965ca7d
commit 965a2fb320
36 changed files with 1253 additions and 1243 deletions

View File

@@ -459,6 +459,9 @@ class Phi3ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
"long_factor": [5.0 for _ in range(n_factors)],
}
input_tensor = ids_tensor([1, 4090], config.vocab_size)
# Make sure we don't have padding tokens. If this is the case, then the actual number of "true" tokens may be shorter
# than `config.original_max_position_embeddings + 5`, invalidating this test
input_tensor[input_tensor == config.pad_token_id] += 1
model = Phi3ForCausalLM(config)
model.to(torch_device)
model.eval()

View File

@@ -311,10 +311,10 @@ class RopeTest(unittest.TestCase):
self.assertEqual(config.rope_theta, 10000.0)
self.assertFalse(hasattr(config, "partial_rotary_factor"))
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on `factor`
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on the seq_len
dim = config.hidden_size // config.num_attention_heads
short_factor = [2.0] * (dim // 2) # scaling applied when factor == 1.0
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when factor > 1.0
short_factor = [2.0] * (dim // 2) # scaling applied when seq_len <= max_position_embeddings
long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when seq_len > max_position_embeddings
rope_fn = ROPE_INIT_FUNCTIONS["default"]
default_inv_freq, _ = rope_fn(config=config, device=torch_device)
@@ -353,26 +353,18 @@ class RopeTest(unittest.TestCase):
# Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
rope_config_validation(config)
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
factor = 1.0
# Check 2: seq_len == 0 -> short factor is applied to the default frequencies
config.rope_scaling = {
"rope_type": "longrope",
"factor": factor,
"factor": 1.0,
"short_factor": short_factor,
"long_factor": long_factor,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=0)
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))
# Check 3: Factor > 1.0 -> long factor is applied to the default frequencies
factor = 10.0
config.rope_scaling = {
"rope_type": "longrope",
"factor": factor,
"short_factor": short_factor,
"long_factor": long_factor,
}
inv_freq, _ = rope_fn(config=config, device=torch_device)
# Check 3: seq_len > max_position_embeddings -> long factor is applied to the default frequencies
inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=config.max_position_embeddings + 1)
torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
def test_llama3_rope_numerically(self):