[tf/flax] handle forced_decoder_ids deletion (#38316)
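Fix the TF and Flax generation mixins: read `forced_decoder_ids` via `getattr(generation_config, "forced_decoder_ids", None)` so that a generation config from which the attribute has been deleted is handled instead of raising `AttributeError`.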
This commit is contained in:
@@ -531,13 +531,16 @@ class FlaxGenerationMixin:
                 if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                 else begin_index + 1
             )
-            if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0:
+            if (
+                getattr(generation_config, "forced_decoder_ids", None) is not None
+                and len(generation_config.forced_decoder_ids) > 0
+            ):
                 # generation starts after the last token that is forced
                 begin_index += generation_config.forced_decoder_ids[-1][0]
             processors.append(
                 FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
             )
-        if generation_config.forced_decoder_ids is not None:
+        if getattr(generation_config, "forced_decoder_ids", None) is not None:
             forced_decoder_ids = [
                 [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids
             ]
@@ -1490,14 +1490,14 @@ class TFGenerationMixin:
                 if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                 else begin_index + 1
             )
-            if generation_config.forced_decoder_ids is not None:
+            if getattr(generation_config, "forced_decoder_ids", None) is not None:
                 begin_index += generation_config.forced_decoder_ids[-1][
                     0
                 ]  # generation starts after the last token that is forced
             processors.append(
                 TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
             )
-        if generation_config.forced_decoder_ids is not None:
+        if getattr(generation_config, "forced_decoder_ids", None) is not None:
             processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))

         processors = self._merge_criteria_processor_list(processors, logits_processor)
Reference in New Issue
Block a user