Generate: fix assistant in different device (#33257)
This commit is contained in:
@@ -3323,7 +3323,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_assisted_decoding_in_gpu_cpu(self):
|
||||
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
|
||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||
|
||||
Reference in New Issue
Block a user