Cache: models return input cache type (#30716)
This commit is contained in:
@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
past_seen_tokens = 0
|
past_seen_tokens = 0
|
||||||
|
return_legacy_cache = False
|
||||||
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = True
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
if use_cache:
|
if return_legacy_cache:
|
||||||
next_cache = (
|
next_cache = next_cache.to_legacy_cache()
|
||||||
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
|
|
||||||
)
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
|
|||||||
@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
|
|
||||||
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
|
inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
|
||||||
|
|
||||||
|
return_legacy_cache = False
|
||||||
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = True
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
if use_cache:
|
if return_legacy_cache:
|
||||||
next_cache = (
|
next_cache = next_cache.to_legacy_cache()
|
||||||
next_decoder_cache.to_legacy_cache()
|
|
||||||
if isinstance(next_decoder_cache, DynamicCache)
|
|
||||||
else next_decoder_cache
|
|
||||||
)
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
|
|||||||
@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
return_legacy_cache = False
|
||||||
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = True
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
if use_cache:
|
if return_legacy_cache:
|
||||||
next_cache = (
|
next_cache = next_cache.to_legacy_cache()
|
||||||
next_decoder_cache.to_legacy_cache()
|
|
||||||
if isinstance(next_decoder_cache, DynamicCache)
|
|
||||||
else next_decoder_cache
|
|
||||||
)
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
|
|||||||
@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
return_legacy_cache = False
|
||||||
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = True
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
if use_cache:
|
if return_legacy_cache:
|
||||||
next_cache = (
|
next_cache = next_cache.to_legacy_cache()
|
||||||
next_decoder_cache.to_legacy_cache()
|
|
||||||
if isinstance(next_decoder_cache, DynamicCache)
|
|
||||||
else next_decoder_cache
|
|
||||||
)
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
|
|||||||
@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
return_legacy_cache = False
|
||||||
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
|
||||||
|
return_legacy_cache = True
|
||||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
if cache_position is None:
|
if cache_position is None:
|
||||||
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = None
|
next_cache = next_decoder_cache if use_cache else None
|
||||||
if use_cache:
|
if return_legacy_cache:
|
||||||
next_cache = (
|
next_cache = next_cache.to_legacy_cache()
|
||||||
next_decoder_cache.to_legacy_cache()
|
|
||||||
if isinstance(next_decoder_cache, DynamicCache)
|
|
||||||
else next_decoder_cache
|
|
||||||
)
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||||
return BaseModelOutputWithPast(
|
return BaseModelOutputWithPast(
|
||||||
|
|||||||
@@ -16,8 +16,6 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import CohereConfig, is_torch_available
|
from transformers import CohereConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
def test_config(self):
|
def test_config(self):
|
||||||
self.config_tester.run_common_tests()
|
self.config_tester.run_common_tests()
|
||||||
|
|
||||||
@unittest.skip("TODO @gante fix this for Cohere")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_model(self):
|
def test_model(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|||||||
@@ -17,8 +17,6 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import DbrxConfig, is_torch_available
|
from transformers import DbrxConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
def test_tied_weights_keys(self):
|
def test_tied_weights_keys(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@unittest.skip("TODO @gante fix this for Llama")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class DbrxModelIntegrationTest(unittest.TestCase):
|
class DbrxModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||||
|
|
||||||
@unittest.skip("TODO @gante fix this for Llama")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
|
@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
|
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
|
||||||
)
|
)
|
||||||
|
|
||||||
@unittest.skip("TODO @gante fix this for Llama")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
class LlamaIntegrationTest(unittest.TestCase):
|
class LlamaIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
|
|||||||
# The output should be different for long inputs
|
# The output should be different for long inputs
|
||||||
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
|
||||||
|
|
||||||
@unittest.skip("TODO @gante fix this for OLMo")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class OlmoIntegrationTest(unittest.TestCase):
|
class OlmoIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -15,8 +15,6 @@
|
|||||||
""" Testing suite for the PyTorch RecurrentGemma model. """
|
""" Testing suite for the PyTorch RecurrentGemma model. """
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from parameterized import parameterized
|
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_bitsandbytes,
|
require_bitsandbytes,
|
||||||
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
config_and_inputs[0].position_embedding_type = type
|
config_and_inputs[0].position_embedding_type = type
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skip("Recurrent gemma does not use legacy cache")
|
|
||||||
@parameterized.expand([(1, False), (1, True), (4, False)])
|
|
||||||
def test_new_cache_format(self, num_beams, do_sample):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_save_load_fast_init_from_base(self):
|
def test_save_load_fast_init_from_base(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user