Refactor (and fix) gpt_neox (#35610)

* start a nice modular

* Update modular_gpt_neox.py

* Update modular_gpt_neox.py

* Update modular_gpt_neox.py

* Update modular_gpt_neox.py

* update

* Update modular_gpt_neox.py

* convert

* fix attribute

* fix attrs

* oups

* fix

* fix

* fix

* fix

* fix

* fix order to pass test (see with accelerate team)

* trigger CIs

* modular

* update

* up

* Update test_modeling_gpt_neox.py

* Update test_modeling_gpt_neox.py

* trigger CIs

* correctly pass arg

* simplify

* remove key warning

* update tp -> it's compatible since the view is before

* trigger CIs
This commit is contained in:
Cyril Vallez
2025-02-04 11:18:43 +01:00
committed by GitHub
parent ad30598923
commit 9afb904b15
6 changed files with 1159 additions and 608 deletions

View File

@@ -18,7 +18,7 @@ import unittest
from parameterized import parameterized
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed
from transformers import AutoTokenizer, DynamicCache, GPTNeoXConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
@@ -232,13 +232,22 @@ class GPTNeoXModelTester:
cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
def copy_cache(cache: DynamicCache):
"""Deep copy a DynamicCache to reuse the same one multiple times."""
new_cache = cache
for i in range(len(cache)):
new_cache.key_cache[i] = cache.key_cache[i].clone()
new_cache.value_cache[i] = cache.value_cache[i].clone()
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
# We need to run both on a copy of the cache, otherwise it is modified in-place
cache_outputs = model(**cache_inputs)
cache = cache_outputs.past_key_values
full_outputs_with_attention_mask = model(
**non_cache_inputs, past_key_values=cache_outputs.past_key_values
**non_cache_inputs, past_key_values=copy_cache(cache)
).last_hidden_state
full_outputs_without_attention_mask = model(
non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
non_cache_inputs["input_ids"], past_key_values=copy_cache(cache)
).last_hidden_state
self.parent.assertTrue(