Remove max length beam scorer (#11378)

* removed max_len

* removed max_length from BeamSearchScorer

* correct max length

* finish

* del vim

* finish & add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Ashwin Geet D'Sa
2021-04-27 00:28:40 +02:00
committed by GitHub
parent bc2571e61c
commit 741d48f5c7
6 changed files with 91 additions and 38 deletions

View File

@@ -59,7 +59,6 @@ class BeamSearchTester:
def prepare_beam_scorer(self, **kwargs):
return BeamSearchScorer(
batch_size=kwargs.get("batch_size", self.batch_size),
max_length=kwargs.get("max_length", self.max_length),
num_beams=kwargs.get("num_beams", self.num_beams),
device=torch_device,
length_penalty=kwargs.get("length_penalty", self.length_penalty),
@@ -170,9 +169,7 @@ class BeamSearchTester:
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
# max_length should be only one more than current input_ids to check that eos is correctly appended
max_length = self.sequence_length + 1
beam_scorer = self.prepare_beam_scorer(
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
)
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
# update beams and append to input_ids
tokens = next_tokens.clone()
@@ -197,6 +194,7 @@ class BeamSearchTester:
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
@@ -225,6 +223,7 @@ class BeamSearchTester:
output_indices,
pad_token_id=self.pad_token_id,
eos_token_id=self.eos_token_id,
max_length=max_length,
)
sequences = sequence_output["sequences"]
sequence_scores = sequence_output["sequence_scores"]