Generate: unify LogitsWarper and LogitsProcessor (#32626)
This commit is contained in:
@@ -118,26 +118,24 @@ class GenerationTesterMixin:
|
||||
|
||||
return config, input_ids, attention_mask
|
||||
|
||||
@staticmethod
|
||||
def _get_logits_processor_and_warper_kwargs(
|
||||
input_length,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
):
|
||||
process_kwargs = {
|
||||
def _get_logits_processor_kwargs(self, do_sample=False):
|
||||
logits_processor_kwargs = {
|
||||
"bad_words_ids": [[1, 0]],
|
||||
"repetition_penalty": 1.2,
|
||||
"remove_invalid_values": True,
|
||||
}
|
||||
# NoRepeatNGramLogitsProcessor + forced tokens may result in no valid continuations
|
||||
if forced_bos_token_id is None and forced_eos_token_id is None:
|
||||
process_kwargs["no_repeat_ngram_size"] = 2
|
||||
if do_sample:
|
||||
logits_processor_kwargs.update(
|
||||
{
|
||||
"top_k": 10,
|
||||
"top_p": 0.7,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
)
|
||||
|
||||
warp_kwargs = {"top_k": 10, "top_p": 0.7, "temperature": 0.7}
|
||||
return process_kwargs, warp_kwargs
|
||||
return logits_processor_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _get_beam_kwargs(num_return_sequences=1):
|
||||
def _get_beam_kwargs(self, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
@@ -146,8 +144,7 @@ class GenerationTesterMixin:
|
||||
}
|
||||
return beam_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _get_diverse_beam_kwargs(num_return_sequences=1):
|
||||
def _get_diverse_beam_kwargs(self, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
@@ -158,8 +155,7 @@ class GenerationTesterMixin:
|
||||
}
|
||||
return beam_kwargs
|
||||
|
||||
@staticmethod
|
||||
def _get_constrained_beam_kwargs(num_return_sequences=1):
|
||||
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
|
||||
beam_kwargs = {
|
||||
"early_stopping": False,
|
||||
"length_penalty": 2.0,
|
||||
@@ -199,12 +195,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -216,7 +207,7 @@ class GenerationTesterMixin:
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**logits_process_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -228,8 +219,6 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
attention_mask,
|
||||
num_return_sequences,
|
||||
logits_warper_kwargs,
|
||||
process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
@@ -237,6 +226,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -249,8 +239,7 @@ class GenerationTesterMixin:
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**logits_warper_kwargs,
|
||||
**process_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -262,13 +251,13 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
attention_mask,
|
||||
beam_kwargs,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -280,7 +269,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -292,7 +281,6 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
attention_mask,
|
||||
beam_kwargs,
|
||||
logits_warper_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
@@ -300,6 +288,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
torch.manual_seed(0)
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=True)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -311,7 +300,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**beam_kwargs,
|
||||
**logits_warper_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -323,13 +312,13 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
attention_mask,
|
||||
beam_kwargs,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -341,7 +330,7 @@ class GenerationTesterMixin:
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -354,13 +343,13 @@ class GenerationTesterMixin:
|
||||
attention_mask,
|
||||
constraints,
|
||||
beam_kwargs,
|
||||
logits_process_kwargs,
|
||||
output_scores=False,
|
||||
output_logits=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict_in_generate=False,
|
||||
):
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -373,7 +362,7 @@ class GenerationTesterMixin:
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
constraints=constraints,
|
||||
**beam_kwargs,
|
||||
**logits_process_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -395,12 +384,7 @@ class GenerationTesterMixin:
|
||||
"top_k": 5,
|
||||
}
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False)
|
||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||
output_generate = model.generate(
|
||||
input_ids,
|
||||
@@ -412,7 +396,7 @@ class GenerationTesterMixin:
|
||||
output_scores=output_scores,
|
||||
output_logits=output_logits,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**logits_process_kwargs,
|
||||
**logits_processor_kwargs,
|
||||
**model_kwargs,
|
||||
**contrastive_search_kwargs,
|
||||
)
|
||||
@@ -495,19 +479,11 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_return_sequences=1,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -521,20 +497,11 @@ class GenerationTesterMixin:
|
||||
|
||||
config.use_cache = False
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
output_generate = self._sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_return_sequences=2,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
process_kwargs=process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -561,19 +528,12 @@ class GenerationTesterMixin:
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
output_generate = self._beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -589,18 +549,12 @@ class GenerationTesterMixin:
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
output_generate = self._beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -633,12 +587,6 @@ class GenerationTesterMixin:
|
||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
config.use_cache = True
|
||||
@@ -649,7 +597,6 @@ class GenerationTesterMixin:
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -693,17 +640,13 @@ class GenerationTesterMixin:
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
output_generate = self._beam_sample_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -711,7 +654,13 @@ class GenerationTesterMixin:
|
||||
else:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||
|
||||
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
|
||||
prepare_inputs_for_generation_args = set(inspect.signature(model.prepare_inputs_for_generation).parameters)
|
||||
# `inputs_embeds` input is well supported when `cache_positions` is used, because it means the modeling
|
||||
# code is up to date with our most recent standards
|
||||
if (
|
||||
"inputs_embeds" in prepare_inputs_for_generation_args
|
||||
and "cache_positions" in prepare_inputs_for_generation_args
|
||||
):
|
||||
input_embeds = model.get_input_embeddings()(input_ids)
|
||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||
output_generate2 = self._beam_sample_generate(
|
||||
@@ -719,7 +668,6 @@ class GenerationTesterMixin:
|
||||
input_ids=None,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output_generate[:, input_embeds.shape[1] :], output_generate2)
|
||||
@@ -732,7 +680,6 @@ class GenerationTesterMixin:
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||
beam_kwargs = self._get_beam_kwargs()
|
||||
|
||||
output_generate = self._beam_sample_generate(
|
||||
@@ -740,7 +687,6 @@ class GenerationTesterMixin:
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_warper_kwargs=logits_warper_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -788,12 +734,6 @@ class GenerationTesterMixin:
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
# check `generate()` and `group_beam_search()` are equal
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
output_generate = self._group_beam_search_generate(
|
||||
@@ -801,7 +741,6 @@ class GenerationTesterMixin:
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -816,7 +755,6 @@ class GenerationTesterMixin:
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||
@@ -829,19 +767,12 @@ class GenerationTesterMixin:
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||
output_generate = self._group_beam_search_generate(
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
@@ -871,12 +802,6 @@ class GenerationTesterMixin:
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
max_id = config.vocab_size
|
||||
@@ -893,7 +818,6 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -919,7 +843,6 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
)
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -938,11 +861,6 @@ class GenerationTesterMixin:
|
||||
config.use_cache = False
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||
input_ids.shape[-1],
|
||||
config.forced_bos_token_id,
|
||||
config.forced_eos_token_id,
|
||||
)
|
||||
|
||||
# Sample constraints
|
||||
min_id = 3
|
||||
@@ -959,7 +877,6 @@ class GenerationTesterMixin:
|
||||
attention_mask=attention_mask,
|
||||
constraints=constraints,
|
||||
beam_kwargs=beam_kwargs,
|
||||
logits_process_kwargs=logits_process_kwargs,
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
|
||||
Reference in New Issue
Block a user