Fix convert_and_export_with_cache failures for GPU models (#38976)

* Add the `device` option for `generate()`

* Add device for default tensors to avoid tensor mismatch

* [test] Enable test_static_cache_exportability for torch_device

* infer device from the prompt_token_ids

* Add device for generated tensor

* [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU

* fix format

* infer device from the model
This commit is contained in:
Stonepia
2025-07-17 21:12:32 +08:00
committed by GitHub
parent 54680d75c9
commit fc700c2a26
12 changed files with 57 additions and 27 deletions

View File

@@ -700,7 +700,7 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.skipTest(reason="This test requires torch >= 2.3 to run.")
set_seed(0)
device = "cpu"
device = torch_device
dtype = "bfloat16"
cache_implementation = "static"
attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
@@ -748,8 +748,8 @@ class CacheExportIntegrationTest(unittest.TestCase):
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
# Export with dynamic shapes
input_ids = torch.zeros((1, 3), dtype=torch.long)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long)
input_ids = torch.zeros((1, 3), dtype=torch.long, device=device)
cache_position = torch.tensor([0, 1, 2], dtype=torch.long, device=device)
dynamic_shapes = {"input_ids": {1: torch.export.Dim.DYNAMIC}, "cache_position": {0: torch.export.Dim.DYNAMIC}}
strict = version.parse(torch.__version__) != version.parse("2.7.0")
exported_program = convert_and_export_with_cache(