Replace all deprecated jax.ops operations with jnp's at (#16078)
* Replace all deprecated `jax.ops` operations with jnp's `at` * np to jnp scores * suggested changes
This commit is contained in:
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
|||||||
|
|
||||||
emb = jnp.zeros((50264, model.config.hidden_size))
|
emb = jnp.zeros((50264, model.config.hidden_size))
|
||||||
# update the first 50257 weights using pre-trained weights
|
# update the first 50257 weights using pre-trained weights
|
||||||
emb = jax.ops.index_update(emb, jax.ops.index[:50257, :], model.params["transformer"]["wte"]["embedding"])
|
emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"])
|
||||||
params = model.params
|
params = model.params
|
||||||
params["transformer"]["wte"]["embedding"] = emb
|
params["transformer"]["wte"]["embedding"] = emb
|
||||||
|
|
||||||
|
|||||||
@@ -142,10 +142,11 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
|
|||||||
score_mask = cumulative_probs < self.top_p
|
score_mask = cumulative_probs < self.top_p
|
||||||
|
|
||||||
# include the token that is higher than top_p as well
|
# include the token that is higher than top_p as well
|
||||||
score_mask |= jax.ops.index_update(jnp.roll(score_mask, 1), jax.ops.index[:, 0], True)
|
score_mask = jnp.roll(score_mask, 1)
|
||||||
|
score_mask |= score_mask.at[jax.ops.index[:, 0]].set(True)
|
||||||
|
|
||||||
# min tokens to keep
|
# min tokens to keep
|
||||||
score_mask = jax.ops.index_update(score_mask, jax.ops.index[:, : self.min_tokens_to_keep], True)
|
score_mask = score_mask.at[jax.ops.index[:, : self.min_tokens_to_keep]].set(True)
|
||||||
|
|
||||||
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
|
||||||
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]
|
||||||
@@ -184,7 +185,7 @@ class FlaxTopKLogitsWarper(FlaxLogitsWarper):
|
|||||||
topk_scores_flat = topk_scores.flatten()
|
topk_scores_flat = topk_scores.flatten()
|
||||||
topk_indices_flat = topk_indices.flatten() + shift
|
topk_indices_flat = topk_indices.flatten() + shift
|
||||||
|
|
||||||
next_scores_flat = jax.ops.index_update(next_scores_flat, topk_indices_flat, topk_scores_flat)
|
next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)
|
||||||
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
next_scores = next_scores_flat.reshape(batch_size, vocab_size)
|
||||||
return next_scores
|
return next_scores
|
||||||
|
|
||||||
@@ -206,9 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
|
|
||||||
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
apply_penalty = 1 - jnp.bool_(cur_len - 1)
|
||||||
|
|
||||||
scores = jnp.where(
|
scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.bos_token_id]].set(0), scores)
|
||||||
apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.bos_token_id], 0), scores
|
|
||||||
)
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -233,9 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
|
|
||||||
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)
|
||||||
|
|
||||||
scores = jnp.where(
|
scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.eos_token_id]].set(0), scores)
|
||||||
apply_penalty, jax.ops.index_update(new_scores, jax.ops.index[:, self.eos_token_id], 0), scores
|
|
||||||
)
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
@@ -266,8 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
|
|||||||
# create boolean flag to decide if min length penalty should be applied
|
# create boolean flag to decide if min length penalty should be applied
|
||||||
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)
|
||||||
|
|
||||||
scores = jnp.where(
|
scores = jnp.where(apply_penalty, scores.at[jax.ops.index[:, self.eos_token_id]].set(-float("inf")), scores)
|
||||||
apply_penalty, jax.ops.index_update(scores, jax.ops.index[:, self.eos_token_id], -float("inf")), scores
|
|
||||||
)
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
@@ -922,7 +922,7 @@ class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
|
||||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
decoder_input_ids = input_ids
|
decoder_input_ids = input_ids
|
||||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|||||||
@@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
|
|||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = (~logits_mask).astype("i4")
|
token_type_ids = (~logits_mask).astype("i4")
|
||||||
logits_mask = jnp.expand_dims(logits_mask, axis=2)
|
logits_mask = jnp.expand_dims(logits_mask, axis=2)
|
||||||
logits_mask = jax.ops.index_update(logits_mask, jax.ops.index[:, 0], False)
|
logits_mask = logits_mask.at[jax.ops.index[:, 0]].set(False)
|
||||||
|
|
||||||
# init input tensors if not passed
|
# init input tensors if not passed
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
|
|||||||
@@ -897,7 +897,7 @@ class FlaxBlenderbotPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxBlenderbotForSequenceClassificationModule
|
||||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
decoder_input_ids = input_ids
|
decoder_input_ids = input_ids
|
||||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|||||||
@@ -895,7 +895,7 @@ class FlaxBlenderbotSmallPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxBlenderbotSmallForSequenceClassificationModule
|
||||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
decoder_input_ids = input_ids
|
decoder_input_ids = input_ids
|
||||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|||||||
@@ -892,7 +892,7 @@ class FlaxMarianPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxMarianForSequenceClassificationModule
|
||||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
decoder_input_ids = input_ids
|
decoder_input_ids = input_ids
|
||||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||||
@@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
|
|||||||
|
|
||||||
def _adapt_logits_for_beam_search(self, logits):
|
def _adapt_logits_for_beam_search(self, logits):
|
||||||
"""This function enforces the padding token never to be generated."""
|
"""This function enforces the padding token never to be generated."""
|
||||||
logits = jax.ops.index_update(logits, jax.ops.index[:, :, self.config.pad_token_id], float("-inf"))
|
logits = logits.at[jax.ops.index[:, :, self.config.pad_token_id]].set(float("-inf"))
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
|
|||||||
@@ -960,7 +960,7 @@ class FlaxMBartPreTrainedModel(FlaxPreTrainedModel):
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
|
# make sure initialization pass will work for FlaxMBartForSequenceClassificationModule
|
||||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
decoder_input_ids = input_ids
|
decoder_input_ids = input_ids
|
||||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|||||||
@@ -964,9 +964,8 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
|
|
||||||
# these two operations makes sure that all values
|
# these two operations makes sure that all values
|
||||||
# before the output lengths indices are attended to
|
# before the output lengths indices are attended to
|
||||||
attention_mask = jax.ops.index_update(
|
idx = jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1]
|
||||||
attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1
|
attention_mask = attention_mask.at[idx].set(1)
|
||||||
)
|
|
||||||
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
||||||
|
|
||||||
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
||||||
@@ -1038,7 +1037,8 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
|
|
||||||
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
|
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
|
||||||
# these two operations makes sure that all values before the output lengths idxs are attended to
|
# these two operations makes sure that all values before the output lengths idxs are attended to
|
||||||
attention_mask = attention_mask.at[(jnp.arange(attention_mask.shape[0]), output_lengths - 1)].set(1)
|
idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
|
||||||
|
attention_mask = attention_mask.at[idx].set(1)
|
||||||
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
|
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
|
||||||
|
|
||||||
attention_mask = jnp.array(attention_mask, dtype=bool)
|
attention_mask = jnp.array(attention_mask, dtype=bool)
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
|
|||||||
Shift input ids one token to the right.
|
Shift input ids one token to the right.
|
||||||
"""
|
"""
|
||||||
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
||||||
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
|
shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
|
||||||
# replace possible -100 values in labels by `pad_token_id`
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
||||||
|
|
||||||
|
|||||||
@@ -1330,7 +1330,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
|
|||||||
Shift input ids one token to the right.
|
Shift input ids one token to the right.
|
||||||
"""
|
"""
|
||||||
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
||||||
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
|
shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
|
||||||
# replace possible -100 values in labels by `pad_token_id`
|
# replace possible -100 values in labels by `pad_token_id`
|
||||||
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
||||||
|
|
||||||
@@ -2040,7 +2040,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
|||||||
# init input tensors
|
# init input tensors
|
||||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||||
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
|
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
|
||||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||||
attention_mask = jnp.ones_like(input_ids)
|
attention_mask = jnp.ones_like(input_ids)
|
||||||
decoder_input_ids = input_ids
|
decoder_input_ids = input_ids
|
||||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ if is_flax_available():
|
|||||||
@require_flax
|
@require_flax
|
||||||
class LogitsProcessorTest(unittest.TestCase):
|
class LogitsProcessorTest(unittest.TestCase):
|
||||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||||
scores = np.ones((batch_size, length)) / length
|
scores = jnp.ones((batch_size, length)) / length
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def test_temperature_dist_warper(self):
|
def test_temperature_dist_warper(self):
|
||||||
@@ -51,8 +51,8 @@ class LogitsProcessorTest(unittest.TestCase):
|
|||||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||||
|
|
||||||
# tweak scores to not be uniform anymore
|
# tweak scores to not be uniform anymore
|
||||||
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
scores = scores.at[1, 5].set((1 / length) + 0.1) # peak, 1st batch
|
||||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
scores = scores.at[1, 10].set((1 / length) - 0.4) # valley, 1st batch
|
||||||
|
|
||||||
# compute softmax
|
# compute softmax
|
||||||
probs = jax.nn.softmax(scores, axis=-1)
|
probs = jax.nn.softmax(scores, axis=-1)
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
|||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax import jit
|
from jax import jit
|
||||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||||
@@ -219,7 +218,7 @@ class FlaxGenerationTesterMixin:
|
|||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# pad attention mask on the left
|
# pad attention mask on the left
|
||||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||||
|
|
||||||
config.do_sample = False
|
config.do_sample = False
|
||||||
config.max_length = max_length
|
config.max_length = max_length
|
||||||
@@ -239,7 +238,7 @@ class FlaxGenerationTesterMixin:
|
|||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# pad attention mask on the left
|
# pad attention mask on the left
|
||||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||||
|
|
||||||
config.do_sample = True
|
config.do_sample = True
|
||||||
config.max_length = max_length
|
config.max_length = max_length
|
||||||
@@ -259,7 +258,7 @@ class FlaxGenerationTesterMixin:
|
|||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# pad attention mask on the left
|
# pad attention mask on the left
|
||||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||||
|
|
||||||
config.num_beams = 2
|
config.num_beams = 2
|
||||||
config.max_length = max_length
|
config.max_length = max_length
|
||||||
|
|||||||
Reference in New Issue
Block a user