Fix length related warnings in speculative decoding (#29585)
* avoid generation length warning * add tests * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add tests and minor fixes * refine `min_new_tokens` * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add method to prepare length arguments * add test for min length * Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * fix variable naming * empty commit for tests * trigger tests (empty) --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
committed by
GitHub
parent
6cdbd73e01
commit
41579763ee
@@ -148,6 +148,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
self.generation_config.return_dict_in_generate = True
|
||||
self.generation_config.output_scores = True
|
||||
|
||||
# avoid unnecessary warnings that min_length is larger than max_new_tokens
|
||||
self.main_model_min_length = self.generation_config.min_length
|
||||
self.generation_config.min_length = 0
|
||||
self.generation_config.min_new_tokens = None
|
||||
|
||||
def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
|
||||
"""
|
||||
Fetches the candidates to be tried for the current input.
|
||||
@@ -166,6 +171,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
# Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
|
||||
new_cur_len = input_ids.shape[-1]
|
||||
max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
|
||||
min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
|
||||
if max_new_tokens == 0:
|
||||
return input_ids, None
|
||||
|
||||
@@ -186,6 +192,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
|
||||
# 2. Forecast next N tokens using the assistant model.
|
||||
assistant_generation_kwargs = {
|
||||
self.input_ids_key: input_ids,
|
||||
"min_new_tokens": min_new_tokens,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"generation_config": self.generation_config,
|
||||
"logits_processor": self.logits_processor,
|
||||
|
||||
@@ -1173,6 +1173,56 @@ class GenerationMixin:
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
def _prepare_generated_length(
|
||||
self,
|
||||
generation_config,
|
||||
has_default_max_length,
|
||||
has_default_min_length,
|
||||
model_input_name,
|
||||
input_ids_length,
|
||||
inputs_tensor,
|
||||
):
|
||||
"""Prepared max and min length in generaion configs to avoid clashes between similar attributes"""
|
||||
|
||||
if generation_config.max_new_tokens is not None:
|
||||
if not has_default_max_length and generation_config.max_length is not None:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
|
||||
# if both `inputs_embeds` and `input_ids` are passed, we do not correct the length
|
||||
# otherwise we need total length [inputs-embeds-len + new-tokens-len] to not go beyond indicated `max_length``
|
||||
elif (
|
||||
model_input_name == "inputs_embeds"
|
||||
and input_ids_length != inputs_tensor.shape[1]
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
generation_config.max_length -= inputs_tensor.shape[1]
|
||||
|
||||
# same for min length
|
||||
if generation_config.min_new_tokens is not None:
|
||||
if not has_default_min_length:
|
||||
logger.warning(
|
||||
f"Both `min_new_tokens` (={generation_config.min_new_tokens}) and `min_length`(="
|
||||
f"{generation_config.min_length}) seem to have been set. `min_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
)
|
||||
generation_config.min_length = generation_config.min_new_tokens + input_ids_length
|
||||
|
||||
elif (
|
||||
model_input_name == "inputs_embeds"
|
||||
and input_ids_length != inputs_tensor.shape[1]
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
|
||||
|
||||
return generation_config
|
||||
|
||||
def _prepare_generation_config(
|
||||
self, generation_config: GenerationConfig, **kwargs: Dict
|
||||
) -> Tuple[GenerationConfig, Dict]:
|
||||
@@ -1418,24 +1468,15 @@ class GenerationMixin:
|
||||
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||
input_ids_length = input_ids.shape[-1]
|
||||
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||
if generation_config.max_new_tokens is not None:
|
||||
if not has_default_max_length and generation_config.max_length is not None:
|
||||
logger.warning(
|
||||
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||
"Please refer to the documentation for more information. "
|
||||
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
|
||||
generation_config = self._prepare_generated_length(
|
||||
generation_config=generation_config,
|
||||
has_default_max_length=has_default_max_length,
|
||||
has_default_min_length=has_default_min_length,
|
||||
model_input_name=model_input_name,
|
||||
inputs_tensor=inputs_tensor,
|
||||
input_ids_length=input_ids_length,
|
||||
)
|
||||
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||
|
||||
# otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length``
|
||||
elif (
|
||||
model_input_name == "inputs_embeds"
|
||||
and inputs_tensor.shape[:-1] != input_ids.shape
|
||||
and not self.config.is_encoder_decoder
|
||||
):
|
||||
generation_config.max_length -= inputs_tensor.shape[1]
|
||||
generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
|
||||
|
||||
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
|
||||
if generation_config.cache_implementation == "static":
|
||||
@@ -1511,7 +1552,7 @@ class GenerationMixin:
|
||||
)
|
||||
|
||||
# 12. run assisted generate
|
||||
result = self.assisted_decoding(
|
||||
result = self._assisted_decoding(
|
||||
input_ids,
|
||||
candidate_generator=candidate_generator,
|
||||
do_sample=generation_config.do_sample,
|
||||
|
||||
@@ -1977,6 +1977,20 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
|
||||
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
|
||||
|
||||
def test_min_length_if_input_embeds(self):
    # PT-only test. For the same `min_length`, the output generated from `inputs_embeds` is expected to be
    # shorter than the output generated from `input_ids` by exactly the prompt length.
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    prompt = "Today a dragon flew over Paris."

    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)
    embedding_layer = model.get_input_embeddings()
    inputs_embeds = embedding_layer(input_ids)

    prompt_len = input_ids.shape[-1]
    requested_min_length = 10

    # Run generation once per prompt representation, with identical length constraints.
    from_ids = model.generate(input_ids=input_ids, min_length=requested_min_length)
    from_embeds = model.generate(inputs_embeds=inputs_embeds, min_length=requested_min_length)

    self.assertEqual(from_ids.shape[-1], prompt_len + from_embeds.shape[-1])
|
||||
|
||||
def test_custom_stopping_criteria_overload_error(self):
|
||||
# PT-only test: TF doesn't have StoppingCriteria
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
@@ -2539,6 +2553,56 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
|
||||
model.generate(input_ids)
|
||||
self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_length_warning_assisted_generation(self):
    # PT-only test: TF doesn't support assisted decoding yet.
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Both the main model and the assistant pad with the tokenizer's EOS token.
    for lm in (model, assistant):
        lm.config.pad_token_id = tokenizer.eos_token_id

    encoded = tokenizer(["Hello world"], return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)

    # Candidate generation must stay silent: no warning that the min length is not feasible.
    generate_kwargs = {"assistant_model": assistant, "min_new_tokens": 10, "max_length": 20}
    with warnings.catch_warnings(record=True) as warning_list:
        model.generate(input_ids, **generate_kwargs)
        self.assertEqual(len(warning_list), 0)
|
||||
|
||||
def test_generated_length_assisted_generation(self):
    # PT-only test: TF doesn't support assisted decoding yet.
    checkpoint = "hf-internal-testing/tiny-random-gpt2"
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    assistant = AutoModelForCausalLM.from_pretrained(checkpoint).to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # Both the main model and the assistant pad with the tokenizer's EOS token.
    for lm in (model, assistant):
        lm.config.pad_token_id = tokenizer.eos_token_id

    encoded = tokenizer(["Hello world"], return_tensors="pt")
    input_ids = encoded.input_ids.to(torch_device)
    input_length = input_ids.shape[-1]

    # Case 1: explicit lower and upper bounds on the number of new tokens.
    out = model.generate(input_ids, assistant_model=assistant, max_new_tokens=20, min_new_tokens=10)
    total_length = out.shape[-1]
    self.assertTrue((10 + input_length) <= total_length <= (20 + input_length))

    # Case 2: only the lower bound is given; the total length is expected to stay within 20.
    out = model.generate(input_ids, assistant_model=assistant, min_new_tokens=10)
    total_length = out.shape[-1]
    self.assertTrue((input_length + 10) <= total_length <= 20)
|
||||
|
||||
def test_model_kwarg_assisted_decoding_decoder_only(self):
|
||||
# PT-only test: TF doesn't support assisted decoding yet.
|
||||
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user