Fix passing str dtype to static cache (#33741)
Co-authored-by: Guang Yang <guangyang@fb.com>
This commit is contained in:
@@ -68,7 +68,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
 config=self.model.config,
 batch_size=self.model.generation_config.cache_config.batch_size,
 max_cache_len=self.model.generation_config.cache_config.max_cache_len,
-dtype=self.model.config.torch_dtype,
+dtype=self.model.dtype,
 )
 self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
 if self.is_causal:
@@ -181,7 +181,7 @@ class CacheTest(unittest.TestCase):
 
 set_seed(0)
 device = "cpu"
-dtype = torch.float32
+dtype = "bfloat16"
 cache_implementation = "static"
 attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
 batch_size = 1
Reference in New Issue
Block a user