Remove max length beam scorer (#11378)
* removed max_len * removed max_length from BeamSearchScorer * correct max length * finish * del vim * finish & add test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -59,7 +59,6 @@ class BeamSearchTester:
|
||||
def prepare_beam_scorer(self, **kwargs):
|
||||
return BeamSearchScorer(
|
||||
batch_size=kwargs.get("batch_size", self.batch_size),
|
||||
max_length=kwargs.get("max_length", self.max_length),
|
||||
num_beams=kwargs.get("num_beams", self.num_beams),
|
||||
device=torch_device,
|
||||
length_penalty=kwargs.get("length_penalty", self.length_penalty),
|
||||
@@ -170,9 +169,7 @@ class BeamSearchTester:
|
||||
def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
|
||||
# max_length should be only one more than current input_ids to check that eos is correctly appended
|
||||
max_length = self.sequence_length + 1
|
||||
beam_scorer = self.prepare_beam_scorer(
|
||||
num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
|
||||
)
|
||||
beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
|
||||
|
||||
# update beams and append to input_ids
|
||||
tokens = next_tokens.clone()
|
||||
@@ -197,6 +194,7 @@ class BeamSearchTester:
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
sequences = sequence_output["sequences"]
|
||||
@@ -225,6 +223,7 @@ class BeamSearchTester:
|
||||
output_indices,
|
||||
pad_token_id=self.pad_token_id,
|
||||
eos_token_id=self.eos_token_id,
|
||||
max_length=max_length,
|
||||
)
|
||||
sequences = sequence_output["sequences"]
|
||||
sequence_scores = sequence_output["sequence_scores"]
|
||||
|
||||
Reference in New Issue
Block a user