Generate: fix assistant in different device (#33257)
This commit is contained in:
@@ -3964,6 +3964,7 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
# 1. Fetch candidate sequences from a `CandidateGenerator`
|
||||||
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
|
||||||
|
candidate_input_ids = candidate_input_ids.to(self.device)
|
||||||
if candidate_logits is not None:
|
if candidate_logits is not None:
|
||||||
candidate_logits = candidate_logits.to(self.device)
|
candidate_logits = candidate_logits.to(self.device)
|
||||||
|
|
||||||
|
|||||||
@@ -3323,7 +3323,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
def test_assisted_decoding_in_gpu_cpu(self):
|
def test_assisted_decoding_model_in_gpu_assistant_in_cpu(self):
|
||||||
# PT-only test: TF doesn't support assisted decoding yet.
|
# PT-only test: TF doesn't support assisted decoding yet.
|
||||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
|
||||||
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||||
|
|||||||
Reference in New Issue
Block a user