fix assisted decoding (#31401)
* fix assisted decoding
* check None
* fix typo
* fix _prepare_special_tokens
* fix style
* fix lint
* add tests for assisted decoding
* fix style
* fix tests check
This commit is contained in:
@@ -1493,8 +1493,11 @@ class GenerationMixin:
|
|||||||
device = self.device
|
device = self.device
|
||||||
|
|
||||||
token = token_kwargs if token_kwargs is not None else token_self
|
token = token_kwargs if token_kwargs is not None else token_self
|
||||||
if token is None or isinstance(token, torch.Tensor):
|
if token is None:
|
||||||
return token
|
return token
|
||||||
|
elif isinstance(token, torch.Tensor):
|
||||||
|
return token.to(device)
|
||||||
|
|
||||||
return torch.tensor(token, device=device, dtype=torch.long)
|
return torch.tensor(token, device=device, dtype=torch.long)
|
||||||
|
|
||||||
bos_token_id = _tensor_or_none(
|
bos_token_id = _tensor_or_none(
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ from transformers.testing_utils import (
|
|||||||
require_auto_gptq,
|
require_auto_gptq,
|
||||||
require_quanto,
|
require_quanto,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
require_torch_multi_accelerator,
|
require_torch_multi_accelerator,
|
||||||
|
require_torch_multi_gpu,
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -3097,6 +3099,54 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
|||||||
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
|
||||||
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_assisted_decoding_in_different_gpu(self):
|
||||||
|
# PT-only test: TF doesn't support assisted decoding yet.
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda:0")
|
||||||
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||||
|
"cuda:1"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||||
|
model.config.pad_token_id = tokenizer.eos_token_id
|
||||||
|
assistant.config.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
text = "Hello world"
|
||||||
|
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||||
|
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||||
|
input_length = input_ids.shape[-1]
|
||||||
|
|
||||||
|
out = model.generate(
|
||||||
|
input_ids,
|
||||||
|
assistant_model=assistant,
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_assisted_decoding_in_gpu_cpu(self):
|
||||||
|
# PT-only test: TF doesn't support assisted decoding yet.
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to("cuda")
|
||||||
|
assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM").to(
|
||||||
|
"cpu"
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
|
||||||
|
model.config.pad_token_id = tokenizer.eos_token_id
|
||||||
|
assistant.config.pad_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
|
text = "Hello world"
|
||||||
|
tokenized_inputs = tokenizer([text], return_tensors="pt")
|
||||||
|
input_ids = tokenized_inputs.input_ids.to(torch_device)
|
||||||
|
input_length = input_ids.shape[-1]
|
||||||
|
|
||||||
|
out = model.generate(
|
||||||
|
input_ids,
|
||||||
|
assistant_model=assistant,
|
||||||
|
max_new_tokens=20,
|
||||||
|
)
|
||||||
|
self.assertTrue(input_length <= out.shape[-1] <= input_length + 20)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class TokenHealingTestCase(unittest.TestCase):
|
class TokenHealingTestCase(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user