From ee27b3d7df397a44dc88324e5aa639a20bf67e53 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 16 Mar 2022 09:08:55 +0000 Subject: [PATCH] Replace all deprecated `jax.ops` operations with jnp's `at` (#16078) * Replace all deprecated `jax.ops` operations with jnp's `at` * np to jnp scores * suggested changes --- .../jax-projects/model_parallel/README.md | 2 +- .../generation_flax_logits_process.py | 19 +++++++------------ .../models/bart/modeling_flax_bart.py | 2 +- .../models/big_bird/modeling_flax_big_bird.py | 2 +- .../blenderbot/modeling_flax_blenderbot.py | 2 +- .../modeling_flax_blenderbot_small.py | 2 +- .../models/marian/modeling_flax_marian.py | 4 ++-- .../models/mbart/modeling_flax_mbart.py | 2 +- .../models/wav2vec2/modeling_flax_wav2vec2.py | 8 ++++---- .../models/xglm/modeling_flax_xglm.py | 2 +- ...ax_{{cookiecutter.lowercase_modelname}}.py | 4 ++-- .../test_generation_flax_logits_process.py | 6 +++--- .../generation/test_generation_flax_utils.py | 7 +++---- 13 files changed, 28 insertions(+), 34 deletions(-) diff --git a/examples/research_projects/jax-projects/model_parallel/README.md b/examples/research_projects/jax-projects/model_parallel/README.md index 53d1085542..1ff61e9d73 100644 --- a/examples/research_projects/jax-projects/model_parallel/README.md +++ b/examples/research_projects/jax-projects/model_parallel/README.md @@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") emb = jnp.zeros((50264, model.config.hidden_size)) # update the first 50257 weights using pre-trained weights -emb = jax.ops.index_update(emb, jax.ops.index[:50257, :], model.params["transformer"]["wte"]["embedding"]) +emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"]) params = model.params params["transformer"]["wte"]["embedding"] = emb diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index e4ebc343e6..3f9792b0fe 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -142,10 +142,11 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): score_mask = cumulative_probs < self.top_p # include the token that is higher than top_p as well - score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True) + score_mask = jnp.roll(score_mask, 1) + score_mask |= score_mask.at[jax.ops.index[:, 0]].set(True) # min tokens to keep - score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True) + score_mask = score_mask.at[jax.ops.index[:, : self.min_tokens_to_keep]].set(True) topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores) next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1] @@ -184,7 +185,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper): topk_scores_flat = topk_scores.flatten() topk_indices_flat = topk_indices.flatten() + shift - next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat) + next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat) next_scores = next_scores_flat.reshape(batch_size, vocab_size) return next_scores @@ -206,9 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor): apply_penalty = 1 - jnp.bool_(cur_len - 1) - scores = jnp.where( - apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores - ) + scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.bos_token_id]].set(0), scores) return scores @@ -233,9 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor): apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1) - scores = jnp.where( - apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores - ) + scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.eos_token_id]].set(0), scores) return scores @@ -266,8 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor): # create boolean flag to decide if min length penalty should be applied apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) - scores = jnp.where( - apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores - ) + scores = jnp.where(apply_penalty, scores.at[jax.ops.index[:, self.eos_token_id]].set(-float("inf")), scores) return scores diff --git a/src/transformers/models/bart/modeling_flax_bart.py b/src/transformers/models/bart/modeling_flax_bart.py index 386bddbb26..2bb00abdd9 100644 --- a/src/transformers/models/bart/modeling_flax_bart.py +++ b/src/transformers/models/bart/modeling_flax_bart.py @@ -922,7 +922,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxBartForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index c926c58b60..63f92023a4 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel): if token_type_ids is None: token_type_ids = (~logits_mask).astype("i4") logits_mask = jnp.expand_dims(logits_mask, axis=2) - logits_mask = jax.ops.index_update(logits_mask, jax.ops.index[:, 0], False) + logits_mask = logits_mask.at[jax.ops.index[:, 0]].set(False) # init input tensors if not passed if token_type_ids is None: diff --git a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py index 52464b61f6..cf0730f45b 100644 --- a/src/transformers/models/blenderbot/modeling_flax_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_flax_blenderbot.py @@ -897,7 +897,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) diff --git a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py index 2efd1ceed6..bad2ca5e65 100644 --- a/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py @@ -895,7 +895,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index 16b2a5b322..ce1424b374 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -892,7 +892,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxMarianForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) @@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel): def _adapt_logits_for_beam_search(self, logits): """This function enforces the padding token never to be generated.""" - logits = jax.ops.index_update(logits, jax.ops.index[:, :, self.config.pad_token_id], float("-inf")) + logits = logits.at[jax.ops.index[:, :, self.config.pad_token_id]].set(float("-inf")) return logits def prepare_inputs_for_generation( diff --git a/src/transformers/models/mbart/modeling_flax_mbart.py b/src/transformers/models/mbart/modeling_flax_mbart.py index 6e11faa5d0..6130dad25d 100644 --- a/src/transformers/models/mbart/modeling_flax_mbart.py +++ b/src/transformers/models/mbart/modeling_flax_mbart.py @@ -960,7 +960,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel): # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for FlaxMBartForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 68df19f206..95a4072354 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -964,9 +964,8 @@ class FlaxWav2Vec2Module(nn.Module): # these two operations makes sure that all values # before the output lengths indices are attended to - attention_mask = jax.ops.index_update( - attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1 - ) + idx = jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1] + attention_mask = attention_mask.at[idx].set(1) attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool") hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic) @@ -1038,7 +1037,8 @@ class FlaxWav2Vec2Module(nn.Module): attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype) # these two operations makes sure that all values before the output lengths idxs are attended to - attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1) + idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1) + attention_mask = attention_mask.at[idx].set(1) attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1) attention_mask = jnp.array(attention_mask, dtype=bool) diff --git a/src/transformers/models/xglm/modeling_flax_xglm.py b/src/transformers/models/xglm/modeling_flax_xglm.py index 37f2ae9d96..7f5adb52a1 100644 --- a/src/transformers/models/xglm/modeling_flax_xglm.py +++ b/src/transformers/models/xglm/modeling_flax_xglm.py @@ -131,7 +131,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_ Shift input ids one token to the right. """ shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) - shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) + shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 45b906e6e3..7c33e637c9 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -1330,7 +1330,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_ Shift input ids one token to the right. """ shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) - shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) + shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id) # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) @@ -2040,7 +2040,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") # make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule - input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id) + input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id) attention_mask = jnp.ones_like(input_ids) decoder_input_ids = input_ids decoder_attention_mask = jnp.ones_like(input_ids) diff --git a/tests/generation/test_generation_flax_logits_process.py b/tests/generation/test_generation_flax_logits_process.py index 43ed7923d2..aea44252d9 100644 --- a/tests/generation/test_generation_flax_logits_process.py +++ b/tests/generation/test_generation_flax_logits_process.py @@ -41,7 +41,7 @@ if is_flax_available(): @require_flax class LogitsProcessorTest(unittest.TestCase): def _get_uniform_logits(self, batch_size: int, length: int): - scores = np.ones((batch_size, length)) / length + scores = jnp.ones((batch_size, length)) / length return scores def test_temperature_dist_warper(self): @@ -51,8 +51,8 @@ class LogitsProcessorTest(unittest.TestCase): scores = self._get_uniform_logits(batch_size=2, length=length) # tweak scores to not be uniform anymore - scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch - scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch + scores = scores.at[1, 5].set((1 / length) + 0.1) # peak, 1st batch + scores = scores.at[1, 10].set((1 / length) - 0.4) # valley, 1st batch # compute softmax probs = jax.nn.softmax(scores, axis=-1) diff --git a/tests/generation/test_generation_flax_utils.py b/tests/generation/test_generation_flax_utils.py index ebf0b27875..b7b84d8db7 100644 --- a/tests/generation/test_generation_flax_utils.py +++ b/tests/generation/test_generation_flax_utils.py @@ -24,7 +24,6 @@ from transformers.testing_utils import is_pt_flax_cross_test, require_flax if is_flax_available(): import os - import jax import jax.numpy as jnp from jax import jit from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model @@ -219,7 +218,7 @@ class FlaxGenerationTesterMixin: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # pad attention mask on the left - attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0) + attention_mask = attention_mask.at[(0, 0)].set(0) config.do_sample = False config.max_length = max_length @@ -239,7 +238,7 @@ class FlaxGenerationTesterMixin: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # pad attention mask on the left - attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0) + attention_mask = attention_mask.at[(0, 0)].set(0) config.do_sample = True config.max_length = max_length @@ -259,7 +258,7 @@ class FlaxGenerationTesterMixin: config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() # pad attention mask on the left - attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0) + attention_mask = attention_mask.at[(0, 0)].set(0) config.num_beams = 2 config.max_length = max_length