Fix convert_and_export_with_cache failures for GPU models (#38976)

* Add the `device` option for `generate()`

* Add device for default tensors to avoid tensor mismatch

* [test] Enable test_static_cache_exportability for torch_device

* infer device from the prompt_token_ids

* Add device for generated tensor

* [Test] Make `test_export_static_cache` tests to run on devices rather than only CPU

* fix format

* infer device from the model
This commit is contained in:
Stonepia
2025-07-17 21:12:32 +08:00
committed by GitHub
parent 54680d75c9
commit fc700c2a26
12 changed files with 57 additions and 27 deletions

View File

@@ -423,7 +423,7 @@ class GemmaIntegrationTest(unittest.TestCase):
].shape[-1]
# Load model
device = "cpu"
device = torch_device
dtype = torch.bfloat16
cache_implementation = "static"
attn_implementation = "sdpa"