🚨🚨 Generate: standardize beam search behavior across frameworks (#21368)

This commit is contained in:
Joao Gante
2023-02-03 10:24:02 +00:00
committed by GitHub
parent ea55bd86b9
commit f21af26279
10 changed files with 122 additions and 118 deletions

View File

@@ -426,7 +426,7 @@ class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationT
)
input_ids = tokenizer(input_str, return_tensors="np").input_ids
sequences = model.generate(input_ids, num_beams=2, max_length=20).sequences
sequences = model.generate(input_ids, num_beams=2, min_length=None, max_length=20).sequences
output_str = tokenizer.batch_decode(sequences)[0]