Remove max length beam scorer (#11378)
* removed max_len * removed max_length from BeamSearchScorer * correct max length * finish * del vim * finish & add test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -59,7 +59,6 @@ class BeamSearchTester:
|
||||
def prepare_beam_scorer(self, **kwargs):
|
||||
return BeamSearchScorer(
|
||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||
max_length=kwargs.get("max_length", self.max_length),
|
||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||
device=torch_device,
|
||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||
@@ -170,9 +169,7 @@ class BeamSearchTester:
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
beam_scorer = self.prepare_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
|
||||
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
@@ -197,6 +194,7 @@ class BeamSearchTester:
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
@@ -225,6 +223,7 @@ class BeamSearchTester:
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
||||
@@ -148,7 +148,6 @@ class GenerationTesterMixin:
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
@@ -169,7 +168,6 @@ class GenerationTesterMixin:
|
||||
}
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=beam_kwargs["num_beams"],
|
||||
device=torch_device,
|
||||
length_penalty=beam_kwargs["length_penalty"],
|
||||
@@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
@@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
@@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# Beam
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
@@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
# Grouped beam search
|
||||
diverse_beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
num_beam_hyps_to_keep=num_return_sequences,
|
||||
@@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
|
||||
max_length=max_length,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def test_beam_search_warning_if_max_length_is_passed(self):
|
||||
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
|
||||
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
|
||||
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
|
||||
|
||||
batch_size = 1
|
||||
num_beams = 3
|
||||
|
||||
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||
input_ids = input_ids.expand(num_beams, -1)
|
||||
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
|
||||
|
||||
stopping_criteria_max_length = 18
|
||||
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
|
||||
|
||||
with self.assertWarns(UserWarning):
|
||||
beam_scorer = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
max_length=10,
|
||||
)
|
||||
|
||||
generated_ids = bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
beam_scorer_no_max_len = BeamSearchScorer(
|
||||
batch_size=batch_size,
|
||||
num_beams=num_beams,
|
||||
device=torch_device,
|
||||
)
|
||||
|
||||
generated_ids_no_max_len = bart_model.beam_search(
|
||||
input_ids,
|
||||
num_beams=num_beams,
|
||||
stopping_criteria=stopping_criteria,
|
||||
beam_scorer=beam_scorer_no_max_len,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
# BeamSearchScorer max_length should not influence "real" max_length
|
||||
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())
|
||||
|
||||
Reference in New Issue
Block a user