From b013842244df7be96b8cc841491bd1e35e475e36 Mon Sep 17 00:00:00 2001 From: Martin Schmitt Date: Tue, 2 Mar 2021 08:41:54 +0100 Subject: [PATCH] Changed `num_beams` to `num_beams // num_beam_groups` when initialising `PrefixConstrainedLogitsProcessor` in `_get_logits_processor` to fix compatibility issue when constrained decoding is used together with grouped beam search (#10475) --- src/transformers/generation_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index cf94d15e32..1a75718b1d 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -605,7 +605,7 @@ class GenerationMixin: if min_length is not None and eos_token_id is not None and min_length > -1: processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) if prefix_allowed_tokens_fn is not None: - processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams)) + processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups)) if forced_bos_token_id is not None: processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: