From 808997a6343c46fbd7e6c92d507805d51750baf4 Mon Sep 17 00:00:00 2001 From: Guang Yang <42389959+guangy10@users.noreply.github.com> Date: Tue, 1 Oct 2024 00:50:17 -0700 Subject: [PATCH] Fix passing str dtype to static cache (#33741) Co-authored-by: Guang Yang --- src/transformers/integrations/executorch.py | 2 +- tests/utils/test_cache_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index afcba5ebd0..4adfcc39a4 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -68,7 +68,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module): config=self.model.config, batch_size=self.model.generation_config.cache_config.batch_size, max_cache_len=self.model.generation_config.cache_config.max_cache_len, - dtype=self.model.config.torch_dtype, + dtype=self.model.dtype, ) self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) if self.is_causal: diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index bd1b8be512..8392ed18b7 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase): set_seed(0) device = "cpu" - dtype = torch.float32 + dtype = "bfloat16" cache_implementation = "static" attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention batch_size = 1