From 93d3fd86459eb27bc584da29a3d542817a395bca Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 17 Mar 2022 17:51:43 +0100 Subject: [PATCH] remove jax.ops.index (#16220) --- .../jax-projects/model_parallel/README.md | 2 +- src/transformers/generation_flax_logits_process.py | 10 +++++----- .../models/big_bird/modeling_flax_big_bird.py | 2 +- src/transformers/models/marian/modeling_flax_marian.py | 2 +- .../models/wav2vec2/modeling_flax_wav2vec2.py | 3 +-- 5 files changed, 9 insertions(+), 10 deletions(-) diff --git a/examples/research_projects/jax-projects/model_parallel/README.md b/examples/research_projects/jax-projects/model_parallel/README.md index 1ff61e9d73..6b6998b56a 100644 --- a/examples/research_projects/jax-projects/model_parallel/README.md +++ b/examples/research_projects/jax-projects/model_parallel/README.md @@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B") emb = jnp.zeros((50264, model.config.hidden_size)) # update the first 50257 weights using pre-trained weights -emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"]) +emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"]) params = model.params params["transformer"]["wte"]["embedding"] = emb diff --git a/src/transformers/generation_flax_logits_process.py b/src/transformers/generation_flax_logits_process.py index 3f9792b0fe..6db4130c12 100644 --- a/src/transformers/generation_flax_logits_process.py +++ b/src/transformers/generation_flax_logits_process.py @@ -143,10 +143,10 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper): # include the token that is higher than top_p as well score_mask = jnp.roll(score_mask, 1) - score_mask |= score_mask.at[jax.ops.index[:, 0]].set(True) + score_mask |= score_mask.at[:, 0].set(True) # min tokens to keep - score_mask = score_mask.at[jax.ops.index[:, : self.min_tokens_to_keep]].set(True) + score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True) topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores) next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1] @@ -207,7 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor): apply_penalty = 1 - jnp.bool_(cur_len - 1) - scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.bos_token_id]].set(0), scores) + scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores) return scores @@ -232,7 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor): apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1) - scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.eos_token_id]].set(0), scores) + scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores) return scores @@ -263,6 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor): # create boolean flag to decide if min length penalty should be applied apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) - scores = jnp.where(apply_penalty, scores.at[jax.ops.index[:, self.eos_token_id]].set(-float("inf")), scores) + scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores) return scores diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 63f92023a4..d23d25494e 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel): if token_type_ids is None: token_type_ids = (~logits_mask).astype("i4") logits_mask = jnp.expand_dims(logits_mask, axis=2) - logits_mask = logits_mask.at[jax.ops.index[:, 0]].set(False) + logits_mask = logits_mask.at[:, 0].set(False) # init input tensors if not passed if token_type_ids is None: diff --git a/src/transformers/models/marian/modeling_flax_marian.py b/src/transformers/models/marian/modeling_flax_marian.py index ce1424b374..9d8b44c5f9 100644 --- a/src/transformers/models/marian/modeling_flax_marian.py +++ b/src/transformers/models/marian/modeling_flax_marian.py @@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel): def _adapt_logits_for_beam_search(self, logits): """This function enforces the padding token never to be generated.""" - logits = logits.at[jax.ops.index[:, :, self.config.pad_token_id]].set(float("-inf")) + logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf")) return logits def prepare_inputs_for_generation( diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 95a4072354..ae6707c051 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -964,8 +964,7 @@ class FlaxWav2Vec2Module(nn.Module): # these two operations makes sure that all values # before the output lengths indices are attended to - idx = jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1] - attention_mask = attention_mask.at[idx].set(1) + attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1) attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool") hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)