Refactor (and fix) gpt_neox (#35610)
* start a nice modular * Update modular_gpt_neox.py * Update modular_gpt_neox.py * Update modular_gpt_neox.py * Update modular_gpt_neox.py * update * Update modular_gpt_neox.py * convert * fix attribute * fix attrs * oups * fix * fix * fix * fix * fix * fix order to pass test (see with accelerate team) * trigger CIs * modular * update * up * Update test_modeling_gpt_neox.py * Update test_modeling_gpt_neox.py * trigger CIs * correctly pass arg * simplify * remove key warning * update tp -> it's compatible since the view is before * trigger CIs
This commit is contained in:
@@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed
|
||||
from transformers import AutoTokenizer, DynamicCache, GPTNeoXConfig, is_torch_available, set_seed
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -232,13 +232,22 @@ class GPTNeoXModelTester:
|
||||
cache_inputs = {"input_ids": input_ids[:, :cached_len], "attention_mask": input_mask[:, :cached_len]}
|
||||
non_cache_inputs = {"input_ids": input_ids[:, cached_len:], "attention_mask": input_mask}
|
||||
|
||||
def copy_cache(cache: "DynamicCache") -> "DynamicCache":
    """Return an independent copy of ``cache`` that can be consumed by a forward pass.

    The model updates the cache it receives in-place, so reusing one cache for
    several forward passes requires handing each pass its own copy.

    Args:
        cache: The cache to duplicate. Only its ``key_cache`` and ``value_cache``
            lists are read; every entry must provide a ``clone()`` method.

    Returns:
        A new cache object of the same type whose ``key_cache`` and
        ``value_cache`` are fresh lists of cloned entries. All other attributes
        are carried over unchanged. ``cache`` itself is left untouched.
    """
    import copy

    # A shallow copy gives us a distinct cache object that keeps any bookkeeping
    # attributes; the two tensor lists are then rebound to brand-new lists so that
    # in-place updates on the copy can never leak back into the original.
    new_cache = copy.copy(cache)
    new_cache.key_cache = [key.clone() for key in cache.key_cache]
    new_cache.value_cache = [value.clone() for value in cache.value_cache]
    return new_cache
|
||||
|
||||
# Cached forward once with the attention mask provided and the other time without it (which should assume full attention)
|
||||
# We need to run both on a copy of the cache, otherwise it is modified in-place
|
||||
cache_outputs = model(**cache_inputs)
|
||||
cache = cache_outputs.past_key_values
|
||||
full_outputs_with_attention_mask = model(
|
||||
**non_cache_inputs, past_key_values=cache_outputs.past_key_values
|
||||
**non_cache_inputs, past_key_values=copy_cache(cache)
|
||||
).last_hidden_state
|
||||
full_outputs_without_attention_mask = model(
|
||||
non_cache_inputs["input_ids"], past_key_values=cache_outputs.past_key_values
|
||||
non_cache_inputs["input_ids"], past_key_values=copy_cache(cache)
|
||||
).last_hidden_state
|
||||
|
||||
self.parent.assertTrue(
|
||||
|
||||
Reference in New Issue
Block a user