From fc63914399b6f60512c720959f9182b02ae4a45c Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Tue, 10 Oct 2023 12:35:16 -0700 Subject: [PATCH] [JAX] Replace uses of `jnp.array` in types with `jnp.ndarray`. (#26703) `jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Co-authored-by: Peter Hawkins --- .../flax/image-captioning/run_image_captioning_flax.py | 2 +- examples/flax/language-modeling/run_clm_flax.py | 2 +- examples/flax/question-answering/run_qa.py | 2 +- .../run_flax_speech_recognition_seq2seq.py | 2 +- examples/flax/summarization/run_summarization_flax.py | 2 +- examples/flax/text-classification/run_flax_glue.py | 2 +- examples/flax/token-classification/run_flax_ner.py | 2 +- examples/flax/vision/run_image_classification.py | 2 +- .../jax-projects/hybrid_clip/run_hybrid_clip.py | 2 +- .../jax-projects/model_parallel/run_clm_mp.py | 2 +- src/transformers/models/bart/modeling_flax_bart.py | 2 +- src/transformers/models/bert/modeling_flax_bert.py | 2 +- .../models/big_bird/modeling_flax_big_bird.py | 2 +- .../models/blenderbot/modeling_flax_blenderbot.py | 2 +- .../blenderbot_small/modeling_flax_blenderbot_small.py | 2 +- src/transformers/models/electra/modeling_flax_electra.py | 8 ++++---- src/transformers/models/longt5/modeling_flax_longt5.py | 2 +- src/transformers/models/marian/modeling_flax_marian.py | 2 +- src/transformers/models/mt5/modeling_flax_mt5.py | 2 +- src/transformers/models/pegasus/modeling_flax_pegasus.py | 2 +- src/transformers/models/roberta/modeling_flax_roberta.py | 2 +- .../modeling_flax_roberta_prelayernorm.py | 2 +- src/transformers/models/t5/modeling_flax_t5.py | 2 +- .../models/xlm_roberta/modeling_flax_xlm_roberta.py | 2 +- .../modeling_flax_{{cookiecutter.lowercase_modelname}}.py | 2 +- 25 files changed, 28 insertions(+), 28 deletions(-) diff --git a/examples/flax/image-captioning/run_image_captioning_flax.py b/examples/flax/image-captioning/run_image_captioning_flax.py index bbc79977a4..d8c89c1a24 100644 --- a/examples/flax/image-captioning/run_image_captioning_flax.py +++ b/examples/flax/image-captioning/run_image_captioning_flax.py @@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 95e175d494..c61b24f4d7 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -326,7 +326,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index 9cd90f285a..0d35f302f8 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -389,7 +389,7 @@ def create_train_state( # region Create learning rate function def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index 4a2915a31a..8af835b6a4 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -360,7 +360,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( num_train_steps: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps) decay_fn = optax.linear_schedule( diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index d57aa17691..782e9ee88c 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -409,7 +409,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index b42a256531..1535ff8492 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -288,7 +288,7 @@ def create_train_state( def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index f4b40220ff..e06a85cb67 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -340,7 +340,7 @@ def create_train_state( def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py index 66505014ec..4bed9b663f 100644 --- a/examples/flax/vision/run_image_classification.py +++ b/examples/flax/vision/run_image_classification.py @@ -249,7 +249,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py index f54641408f..c5a4a20253 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py @@ -283,7 +283,7 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py index a6da8729f0..bb297e3e0d 100644 --- a/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py +++ b/examples/research_projects/jax-projects/model_parallel/run_clm_mp.py @@ -214,7 +214,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): def create_learning_rate_fn( train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float -) -> Callable[[int], jnp.array]: +) -> Callable[[int], jnp.ndarray]: """Returns a linear warmup, linear_decay learning rate function.""" steps_per_epoch = train_ds_size // train_batch_size num_train_steps = steps_per_epoch * num_train_epochs diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 9858eb2d1b..6abfcdc398 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -217,7 +217,7 @@ BART_DECODE_INPUTS_DOCSTRING = r""" """ -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 99dfa2a0e2..bb2af0e060 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -295,7 +295,7 @@ class FlaxBertSelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index afdac2645f..c6d8b7c161 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -316,7 +316,7 @@ class FlaxBigBirdSelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 1035272fd0..61239335be 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -204,7 +204,7 @@ BLENDERBOT_DECODE_INPUTS_DOCSTRING = r""" # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index 2bf8b59e27..b5272fb3bc 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -216,7 +216,7 @@ BLENDERBOT_SMALL_DECODE_INPUTS_DOCSTRING = r""" # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 32e76b8b58..8fced6ff1e 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -263,7 +263,7 @@ class FlaxElectraSelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, @@ -1228,13 +1228,13 @@ class FlaxElectraSequenceSummary(nn.Module): Compute a single vector summary of a sequence hidden states. Args: - hidden_states (`jnp.array` of shape `[batch_size, seq_len, hidden_size]`): + hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`): The hidden states of the last layer. - cls_index (`jnp.array` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. Returns: - `jnp.array`: The summary of the sequence hidden states. + `jnp.ndarray`: The summary of the sequence hidden states. """ # NOTE: this doest "first" type summary always output = hidden_states[:, 0] diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py index 91ca9c72c2..36e273d572 100644 --- a/src/transformers/models/longt5/modeling_flax_longt5.py +++ b/src/transformers/models/longt5/modeling_flax_longt5.py @@ -56,7 +56,7 @@ remat = nn_partitioning.remat # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index a713fdb05d..5197c90689 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -227,7 +227,7 @@ def create_sinusoidal_positions(n_pos, dim): # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/mt5/modeling_flax_mt5.py b/src/transformers/models/mt5/modeling_flax_mt5.py index 86ddf477ff..0046e02ca7 100644 --- a/src/transformers/models/mt5/modeling_flax_mt5.py +++ b/src/transformers/models/mt5/modeling_flax_mt5.py @@ -27,7 +27,7 @@ _CONFIG_FOR_DOC = "T5Config" # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/pegasus/modeling_flax_pegasus.py b/src/transformers/models/pegasus/modeling_flax_pegasus.py index c5189746b1..17772251bf 100644 --- a/src/transformers/models/pegasus/modeling_flax_pegasus.py +++ b/src/transformers/models/pegasus/modeling_flax_pegasus.py @@ -210,7 +210,7 @@ PEGASUS_DECODE_INPUTS_DOCSTRING = r""" # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 845fcea442..6bc72f12b4 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -256,7 +256,7 @@ class FlaxRobertaSelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py index b7c347693d..e988979937 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py @@ -258,7 +258,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index db4ca90c27..09575fdcc3 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -56,7 +56,7 @@ remat = nn_partitioning.remat # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right -def shift_tokens_right(input_ids: jnp.array, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: +def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: """ Shift input ids one token to the right. """ diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py index f6f39ee93b..fb03c390f6 100644 --- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py @@ -266,7 +266,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False, diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 83263a6a47..63b5d83d30 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -251,7 +251,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module): hidden_states, attention_mask, layer_head_mask, - key_value_states: Optional[jnp.array] = None, + key_value_states: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic=True, output_attentions: bool = False,