From 7f91f168a14c7a305252675730b3248a1c987944 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Wed, 3 Jul 2024 16:22:56 +0800 Subject: [PATCH] fix assisted decoding (#31401) * fix assisted decoding * check None * fix typo * fix _prepare_special_tokens * fix style * fix lint * add tests for assisted decoding * fix style * fix tests check --- src/transformers/generation/utils.py | 5 ++- tests/generation/test_utils.py | 50 ++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 25ec7be1b5..f99ae64fb8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1493,8 +1493,11 @@ class GenerationMixin: device = self.device token = token_kwargs if token_kwargs is not None else token_self - if token is None or isinstance(token, torch.Tensor): + if token is None: return token + elif isinstance(token, torch.Tensor): + return token.to(device) + return torch.tensor(token, device=device, dtype=torch.long) bos_token_id = _tensor_or_none( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b9e962a6a1..8fa41fbdbe 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -30,7 +30,9 @@ from transformers.testing_utils import ( require_auto_gptq, require_quanto, require_torch, + require_torch_gpu, require_torch_multi_accelerator, + require_torch_multi_gpu, slow, torch_device, ) @@ -3097,6 +3099,54 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi self.assertTrue(y_prob > 0.001 and n_prob > 0.001) self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0) + @slow + @require_torch_multi_gpu + def test_assisted_decoding_in_different_gpu(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + "cuda:1" + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + model.config.pad_token_id = tokenizer.eos_token_id + assistant.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=20, + ) + self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) + + @slow + @require_torch_gpu + def test_assisted_decoding_in_gpu_cpu(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda") + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to( + "cpu" + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + model.config.pad_token_id = tokenizer.eos_token_id + assistant.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + input_length = input_ids.shape[-1] + + out = model.generate( + input_ids, + assistant_model=assistant, + max_new_tokens=20, + ) + self.assertTrue(input_length <= out.shape[-1] <= input_length + 20) + @require_torch class TokenHealingTestCase(unittest.TestCase):