Generate: unify LogitsWarper and LogitsProcessor (#32626)

This commit is contained in:
Joao Gante
2024-08-16 11:20:41 +01:00
committed by GitHub
parent 5fd7ca7bc9
commit 70d5df6107
20 changed files with 186 additions and 623 deletions

View File

@@ -68,14 +68,7 @@ if is_torch_available():
set_seed,
)
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
GenerateEncoderDecoderOutput,
PhrasalConstraint,
)
from transformers.generation.logits_process import LogitsProcessor
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder, sinusoids
@@ -419,6 +412,30 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
return False
def _get_logits_processor_kwargs(self, do_sample=False):
# Overwritten from `GenerationTesterMixin`, Whisper needs `"temperature": 0.0` to be able to do beam search
logits_processor_kwargs = super()._get_logits_processor_kwargs(do_sample=do_sample)
logits_processor_kwargs["temperature"] = 0.0
return logits_processor_kwargs
def _get_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def _get_diverse_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def _get_constrained_beam_kwargs(self, num_return_sequences=1):
# Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
return beam_kwargs
def setUp(self):
self.model_tester = WhisperModelTester(self)
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
@@ -1551,241 +1568,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_longform_generate_multi_batch_cond_prev(self):
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
def test_beam_sample_generate_dict_output(self):
# We overwrite test_beam_sample_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = WhisperForConditionalGeneration(config).to(torch_device).eval()
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
beam_kwargs = self._get_beam_kwargs()
# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_warper_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._beam_sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"])
def test_beam_search_generate_dict_output(self):
# We overwrite test_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
# With Whisper, we can only perform a beam search if the temperature is set to 0.
logits_process_kwargs["temperature"] = 0
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
def test_beam_search_generate_dict_outputs_use_cache(self):
# We overwrite test_beam_search_generate_dict_outputs_use_cache in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
# enable cache
config, input_ids, attention_mask = self._get_input_ids_and_config()
if not hasattr(config, "use_cache"):
self.skipTest("This model doesn't support caching")
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
output_generate = self._beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self._check_outputs(
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
)
def test_group_beam_search_generate_dict_output(self):
# We overwrite test_group_beam_search_generate_dict_output in test_utils as
# we can only perform beam search if the temperature is set to 0 in Whisper.
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_diverse_beam_kwargs()
# We will return num_beams sequences per input only if num_return_sequences == num_beams:
beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
output_generate = self._group_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
)
def test_constrained_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
# disable cache
config.use_cache = False
model = model_class(config).to(torch_device).eval()
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
# Sample constraints
min_id = 3
max_id = model.config.vocab_size
force_tokens = torch.randint(min_id, max_id, (1, 2)).tolist()[0]
constraints = [
PhrasalConstraint(force_tokens),
]
beam_kwargs = self._get_constrained_beam_kwargs()
output_generate = self._constrained_beam_search_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
constraints=constraints,
beam_kwargs=beam_kwargs,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
# Retrocompatibility check
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
self._check_outputs(
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_return_sequences"]
)
@is_flaky() # TODO (joao, sanchit): fails ~9% of the times. Does the original test have the same issue?
def test_custom_4d_attention_mask(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()