Generate: fix assistant in different device (#33257)

This commit is contained in:
Joao Gante
2024-09-02 14:37:49 +01:00
committed by GitHub
parent 52a0213755
commit 97c0f45b9c
2 changed files with 2 additions and 1 deletions

View File

@@ -3323,7 +3323,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
@slow
@require_torch_gpu
def test_assisted_decoding_in_gpu_cpu(self):
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
# PT-only test: TF doesn't support assisted decoding yet.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(