From a6e6b1c622d8d08e2510a82cb6266d7b654f1cbf Mon Sep 17 00:00:00 2001 From: mariecwhite Date: Sat, 5 Aug 2023 03:36:57 +1000 Subject: [PATCH] Remove jnp.DeviceArray since it is deprecated. (#24875) * Remove jnp.DeviceArray since it is deprecated. * Replace all instances of jnp.DeviceArray with jax.Array * Update src/transformers/models/bert/modeling_flax_bert.py --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/bart/modeling_flax_bart.py | 6 +++--- src/transformers/models/bert/modeling_flax_bert.py | 2 +- src/transformers/models/big_bird/modeling_flax_big_bird.py | 2 +- .../models/blenderbot/modeling_flax_blenderbot.py | 4 ++-- .../blenderbot_small/modeling_flax_blenderbot_small.py | 4 ++-- src/transformers/models/electra/modeling_flax_electra.py | 2 +- .../models/encoder_decoder/modeling_flax_encoder_decoder.py | 4 ++-- src/transformers/models/gpt2/modeling_flax_gpt2.py | 2 +- src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py | 2 +- src/transformers/models/gptj/modeling_flax_gptj.py | 2 +- src/transformers/models/longt5/modeling_flax_longt5.py | 4 ++-- src/transformers/models/marian/modeling_flax_marian.py | 4 ++-- src/transformers/models/mbart/modeling_flax_mbart.py | 4 ++-- src/transformers/models/opt/modeling_flax_opt.py | 2 +- src/transformers/models/pegasus/modeling_flax_pegasus.py | 4 ++-- src/transformers/models/roberta/modeling_flax_roberta.py | 2 +- .../modeling_flax_roberta_prelayernorm.py | 2 +- .../modeling_flax_speech_encoder_decoder.py | 4 ++-- src/transformers/models/t5/modeling_flax_t5.py | 4 ++-- .../modeling_flax_vision_encoder_decoder.py | 2 +- src/transformers/models/whisper/modeling_flax_whisper.py | 4 ++-- src/transformers/models/xglm/modeling_flax_xglm.py | 2 +- .../models/xlm_roberta/modeling_flax_xlm_roberta.py | 2 +- .../modeling_flax_{{cookiecutter.lowercase_modelname}}.py | 6 +++--- 24 files changed, 38 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index b7ce63ffcc..9858eb2d1b 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -1467,8 +1467,8 @@ class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): @@ -1960,7 +1960,7 @@ class FlaxBartForCausalLMModule(nn.Module): class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel): module_class = FlaxBartForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 6e8eb829b9..99dfa2a0e2 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -1677,7 +1677,7 @@ class FlaxBertForCausalLMModule(nn.Module): class FlaxBertForCausalLM(FlaxBertPreTrainedModel): module_class = FlaxBertForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index a6f503aa3e..afdac2645f 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -2599,7 +2599,7 @@ class FlaxBigBirdForCausalLMModule(nn.Module): class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel): module_class = FlaxBigBirdForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 3f5c73a6c3..1035272fd0 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -1443,8 +1443,8 @@ class FlaxBlenderbotForConditionalGeneration(FlaxBlenderbotPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index 77e6b1704b..2bf8b59e27 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -1441,8 +1441,8 @@ class FlaxBlenderbotSmallForConditionalGeneration(FlaxBlenderbotSmallPreTrainedM self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index f7c150f56d..32e76b8b58 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -1565,7 +1565,7 @@ class FlaxElectraForCausalLMModule(nn.Module): class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel): module_class = FlaxElectraForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index a500398d67..3d9679f26a 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -722,8 +722,8 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index 2a449360b8..8973e081a3 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -742,7 +742,7 @@ class FlaxGPT2LMHeadModule(nn.Module): class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel): module_class = FlaxGPT2LMHeadModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py index 0749911f7a..5639ca50f1 100644 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -654,7 +654,7 @@ class FlaxGPTNeoForCausalLMModule(nn.Module): class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel): module_class = FlaxGPTNeoForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py index 8ec53aec46..9f0d4d6e86 100644 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ b/src/transformers/models/gptj/modeling_flax_gptj.py @@ -683,7 +683,7 @@ class FlaxGPTJForCausalLMModule(nn.Module): class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel): module_class = FlaxGPTJForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py index 96c0b7df2c..6b7bc7c28f 100644 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -2388,8 +2388,8 @@ class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index f197126277..a713fdb05d 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -1436,8 +1436,8 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index aeeec3e583..907fd53aa1 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -1502,8 +1502,8 @@ class FlaxMBartForConditionalGeneration(FlaxMBartPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/opt/modeling_flax_opt.py b/src/transformers/models/opt/modeling_flax_opt.py index b7038008f5..5d9839f120 100644 --- a/src/transformers/models/opt/modeling_flax_opt.py +++ b/src/transformers/models/opt/modeling_flax_opt.py @@ -763,7 +763,7 @@ class FlaxOPTForCausalLMModule(nn.Module): class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel): module_class = FlaxOPTForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index fdf7f019f2..c5189746b1 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -1450,8 +1450,8 @@ class FlaxPegasusForConditionalGeneration(FlaxPegasusPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 5de7375d38..845fcea442 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -1452,7 +1452,7 @@ class FlaxRobertaForCausalLMModule(nn.Module): class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel): module_class = FlaxRobertaForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py index 8f5dc7944c..b7c347693d 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py @@ -1478,7 +1478,7 @@ class FlaxRobertaPreLayerNormForCausalLMModule(nn.Module): class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel): module_class = FlaxRobertaPreLayerNormForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index d7e7bdf57f..b9975510ab 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -745,8 +745,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index cc74c30c1d..b2a7181421 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -1740,8 +1740,8 @@ class FlaxT5ForConditionalGeneration(FlaxT5PreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index 8561875ed5..3d914c9658 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -688,7 +688,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): self, decoder_input_ids, max_length, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index c9f3bbba9c..0f158fb602 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1448,8 +1448,8 @@ class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs, ): diff --git a/src/transformers/models/xglm/modeling_flax_xglm.py b/src/transformers/models/xglm/modeling_flax_xglm.py index b2acd66f44..d6b90a7f00 100644 --- a/src/transformers/models/xglm/modeling_flax_xglm.py +++ b/src/transformers/models/xglm/modeling_flax_xglm.py @@ -766,7 +766,7 @@ class FlaxXGLMForCausalLMModule(nn.Module): class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel): module_class = FlaxXGLMForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py index ed5e113770..f6f39ee93b 100644 --- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py @@ -1469,7 +1469,7 @@ class FlaxXLMRobertaForCausalLMModule(nn.Module): class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel): module_class = FlaxXLMRobertaForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index f6283197b0..83263a6a47 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -1469,7 +1469,7 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel): module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule - def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): # initializing the cache batch_size, seq_length = input_ids.shape @@ -2969,8 +2969,8 @@ class Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(Flax{{coo self, decoder_input_ids, max_length, - attention_mask: Optional[jnp.DeviceArray] = None, - decoder_attention_mask: Optional[jnp.DeviceArray] = None, + attention_mask: Optional[jax.Array] = None, + decoder_attention_mask: Optional[jax.Array] = None, encoder_outputs=None, **kwargs ):