Cache: models return input cache type (#30716)

This commit is contained in:
Joao Gante
2024-05-08 18:26:34 +01:00
committed by GitHub
parent 71c1985069
commit f26e407370
11 changed files with 30 additions and 70 deletions

View File

@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
past_seen_tokens = 0 past_seen_tokens = 0
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None: if cache_position is None:
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = None next_cache = next_decoder_cache if use_cache else None
if use_cache: if return_legacy_cache:
next_cache = ( next_cache = next_cache.to_legacy_cache()
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(

View File

@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None: if cache_position is None:
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = None next_cache = next_decoder_cache if use_cache else None
if use_cache: if return_legacy_cache:
next_cache = ( next_cache = next_cache.to_legacy_cache()
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
if not return_dict: if not return_dict:
return tuple( return tuple(
v v

View File

@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None: if cache_position is None:
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = None next_cache = next_decoder_cache if use_cache else None
if use_cache: if return_legacy_cache:
next_cache = ( next_cache = next_cache.to_legacy_cache()
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(

View File

@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None: if cache_position is None:
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = None next_cache = next_decoder_cache if use_cache else None
if use_cache: if return_legacy_cache:
next_cache = ( next_cache = next_cache.to_legacy_cache()
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(

View File

@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None: if cache_position is None:
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
next_cache = None next_cache = next_decoder_cache if use_cache else None
if use_cache: if return_legacy_cache:
next_cache = ( next_cache = next_cache.to_legacy_cache()
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast( return BaseModelOutputWithPast(

View File

@@ -16,8 +16,6 @@
import unittest import unittest
from parameterized import parameterized
from transformers import CohereConfig, is_torch_available from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
require_bitsandbytes, require_bitsandbytes,
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
@unittest.skip("TODO @gante fix this for Cohere")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_model(self): def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)

View File

@@ -17,8 +17,6 @@
import unittest import unittest
from parameterized import parameterized
from transformers import DbrxConfig, is_torch_available from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def test_tied_weights_keys(self): def test_tied_weights_keys(self):
pass pass
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch @require_torch
class DbrxModelIntegrationTest(unittest.TestCase): class DbrxModelIntegrationTest(unittest.TestCase):

View File

@@ -17,7 +17,6 @@ import tempfile
import unittest import unittest
import pytest import pytest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.testing_utils import ( from transformers.testing_utils import (
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@unittest.skip("Gemma buffers include complex numbers, which breaks this test") @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass

View File

@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}", msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
) )
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch_gpu @require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase): class LlamaIntegrationTest(unittest.TestCase):

View File

@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The output should be different for long inputs # The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@unittest.skip("TODO @gante fix this for OLMo")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch @require_torch
class OlmoIntegrationTest(unittest.TestCase): class OlmoIntegrationTest(unittest.TestCase):

View File

@@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """ """ Testing suite for the PyTorch RecurrentGemma model. """
import unittest import unittest
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import ( from transformers.testing_utils import (
require_bitsandbytes, require_bitsandbytes,
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
config_and_inputs[0].position_embedding_type = type config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip("Recurrent gemma does not use legacy cache")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_from_base(self):
pass pass