Changed num_beams to num_beams // num_beam_groups when initialising PrefixConstrainedLogitsProcessor in _get_logits_processor to fix compatibility issue when constrained decoding is used together with grouped beam search (#10475)
This commit is contained in:
@@ -605,7 +605,7 @@ class GenerationMixin:
|
|||||||
if min_length is not None and eos_token_id is not None and min_length > -1:
|
if min_length is not None and eos_token_id is not None and min_length > -1:
|
||||||
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
|
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
|
||||||
if prefix_allowed_tokens_fn is not None:
|
if prefix_allowed_tokens_fn is not None:
|
||||||
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams))
|
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
|
||||||
if forced_bos_token_id is not None:
|
if forced_bos_token_id is not None:
|
||||||
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||||
if forced_eos_token_id is not None:
|
if forced_eos_token_id is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user