Remove max length beam scorer (#11378)

* removed max_len

* removed max_length from BeamSearchScorer

* correct max length

* finish

* del vim

* finish & add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Ashwin Geet D'Sa
2021-04-27 00:28:40 +02:00
committed by GitHub
parent bc2571e61c
commit 741d48f5c7
6 changed files with 91 additions and 38 deletions

View File

@@ -148,7 +148,6 @@ class GenerationTesterMixin:
}
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
@@ -169,7 +168,6 @@ class GenerationTesterMixin:
}
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=beam_kwargs["num_beams"],
device=torch_device,
length_penalty=beam_kwargs["length_penalty"],
@@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
)
@@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
@@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
# Beam
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
)
@@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
# Grouped beam search
diverse_beam_scorer = BeamSearchScorer(
batch_size=batch_size,
max_length=max_length,
num_beams=num_beams,
device=torch_device,
num_beam_hyps_to_keep=num_return_sequences,
@@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
max_length=max_length,
**model_kwargs,
)
def test_beam_search_warning_if_max_length_is_passed(self):
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
batch_size = 1
num_beams = 3
input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
input_ids = input_ids.expand(num_beams, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
stopping_criteria_max_length = 18
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
with self.assertWarns(UserWarning):
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
max_length=10,
)
generated_ids = bart_model.beam_search(
input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer,
**model_kwargs,
)
beam_scorer_no_max_len = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
device=torch_device,
)
generated_ids_no_max_len = bart_model.beam_search(
input_ids,
num_beams=num_beams,
stopping_criteria=stopping_criteria,
beam_scorer=beam_scorer_no_max_len,
**model_kwargs,
)
# BeamSearchScorer max_length should not influence "real" max_length
self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())