Align `beam_sample` output selection with `beam_search` (#25375)
* Align `beam_sample` specs with `beam_search`
* pull origin main
* Revert "pull origin main"

  This reverts commit 06d356f1137bb52272e120a03636598c44449cf3.
* update test_utils.py
* fix format
* remove comment

---------

Co-authored-by: Shogo Fujita <shogo.fujita@legalontech.jp>
This commit is contained in:
@@ -1691,18 +1691,19 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# 12. prepare beam search scorer
|
# 12. prepare beam search scorer
|
||||||
beam_scorer = BeamSearchScorer(
|
beam_scorer = BeamSearchScorer(
|
||||||
batch_size=batch_size * generation_config.num_return_sequences,
|
batch_size=batch_size,
|
||||||
num_beams=generation_config.num_beams,
|
num_beams=generation_config.num_beams,
|
||||||
device=inputs_tensor.device,
|
device=inputs_tensor.device,
|
||||||
length_penalty=generation_config.length_penalty,
|
length_penalty=generation_config.length_penalty,
|
||||||
do_early_stopping=generation_config.early_stopping,
|
do_early_stopping=generation_config.early_stopping,
|
||||||
|
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
||||||
max_length=generation_config.max_length,
|
max_length=generation_config.max_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 13. interleave input_ids with `num_beams` additional sequences per batch
|
# 13. interleave input_ids with `num_beams` additional sequences per batch
|
||||||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
expand_size=generation_config.num_beams * generation_config.num_return_sequences,
|
expand_size=generation_config.num_beams,
|
||||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -438,7 +438,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
max_length,
|
||||||
num_return_sequences,
|
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_warper,
|
logits_warper,
|
||||||
@@ -463,7 +462,7 @@ class GenerationTesterMixin:
|
|||||||
**logits_warper_kwargs,
|
**logits_warper_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
|
# beam_search does not automatically interleave `batch_size` dim for `num_beams`
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
@@ -471,13 +470,13 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
num_interleave=beam_scorer.num_beams * num_return_sequences,
|
num_interleave=beam_scorer.num_beams,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
)
|
)
|
||||||
kwargs["encoder_outputs"] = encoder_outputs
|
kwargs["encoder_outputs"] = encoder_outputs
|
||||||
elif attention_mask is not None:
|
elif attention_mask is not None:
|
||||||
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
|
attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)
|
||||||
|
|
||||||
# prevent flaky generation test failures
|
# prevent flaky generation test failures
|
||||||
logits_processor = LogitsProcessorList()
|
logits_processor = LogitsProcessorList()
|
||||||
@@ -486,7 +485,7 @@ class GenerationTesterMixin:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_beam_sample = model.beam_sample(
|
output_beam_sample = model.beam_sample(
|
||||||
input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
|
input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
|
||||||
beam_scorer,
|
beam_scorer,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
@@ -891,13 +890,9 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())
|
||||||
|
|
||||||
# check `generate()` and `beam_search()` are equal for `num_return_sequences`
|
|
||||||
num_return_sequences = 2
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
|
|
||||||
)
|
|
||||||
|
|
||||||
output_generate, output_beam_search = self._beam_search_generate(
|
output_generate, output_beam_search = self._beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
@@ -1036,21 +1031,15 @@ class GenerationTesterMixin:
|
|||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
# check `generate()` and `beam_search()` are equal
|
# check `generate()` and `beam_search()` are equal
|
||||||
# change `num_return_sequences = 2` but not for `beam_scorer`
|
|
||||||
num_return_sequences = 2
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
input_ids.shape[0] * num_return_sequences, max_length
|
|
||||||
)
|
|
||||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
|
||||||
|
|
||||||
output_generate, output_beam_sample = self._beam_sample_generate(
|
output_generate, output_beam_sample = self._beam_sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
num_return_sequences=num_return_sequences,
|
|
||||||
beam_scorer=beam_scorer,
|
beam_scorer=beam_scorer,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
@@ -1074,20 +1063,15 @@ class GenerationTesterMixin:
|
|||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
|
||||||
|
|
||||||
num_return_sequences = 2
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
max_length = 4
|
max_length = 4
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)
|
||||||
input_ids.shape[0] * num_return_sequences, max_length
|
|
||||||
)
|
|
||||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
|
||||||
|
|
||||||
output_beam_sample, output_generate = self._beam_sample_generate(
|
output_beam_sample, output_generate = self._beam_sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
num_return_sequences=num_return_sequences,
|
|
||||||
beam_scorer=beam_scorer,
|
beam_scorer=beam_scorer,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper=logits_warper,
|
logits_warper=logits_warper,
|
||||||
@@ -1113,9 +1097,7 @@ class GenerationTesterMixin:
|
|||||||
self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
|
self.assertTrue((output_generate["sequences_scores"] < 0).all().item())
|
||||||
|
|
||||||
for output in (output_beam_sample, output_generate):
|
for output in (output_beam_sample, output_generate):
|
||||||
self._check_outputs(
|
self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)
|
||||||
output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
config, _, _, max_length = self._get_input_ids_and_config()
|
config, _, _, max_length = self._get_input_ids_and_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user