From a94105f95fb66ee4129077c03e4e8a224f6a07fd Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 15 Dec 2021 11:36:28 +0100 Subject: [PATCH] Fix preprocess_function in run_summarization_flax.py (#14769) Co-authored-by: ydshieh --- examples/flax/summarization/run_summarization_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index f82ddeee96..9bb43b89a4 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -533,7 +533,7 @@ def main(): model_inputs["labels"] = labels["input_ids"] decoder_input_ids = shift_tokens_right_fn( - jnp.array(labels["input_ids"]), config.pad_token_id, config.decoder_start_token_id + labels["input_ids"], config.pad_token_id, config.decoder_start_token_id ) model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)