Fix passing str dtype to static cache (#33741)

Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
Guang Yang
2024-10-01 00:50:17 -07:00
committed by GitHub
parent c269c5c74d
commit 808997a634
2 changed files with 2 additions and 2 deletions

View File

@@ -68,7 +68,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
  config=self.model.config,
  batch_size=self.model.generation_config.cache_config.batch_size,
  max_cache_len=self.model.generation_config.cache_config.max_cache_len,
- dtype=self.model.config.torch_dtype,
+ dtype=self.model.dtype,
  )
  self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
  if self.is_causal:

View File

@@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase):
  set_seed(0)
  device = "cpu"
- dtype = torch.float32
+ dtype = "bfloat16"
  cache_implementation = "static"
  attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
  batch_size = 1