Replace all deprecated jax.ops operations with jnp's at (#16078)

* Replace all deprecated `jax.ops` operations with jnp's `at`

* np to jnp scores

* suggested changes
This commit is contained in:
Sanchit Gandhi
2022-03-16 09:08:55 +00:00
committed by GitHub
parent c2dc89be62
commit ee27b3d7df
13 changed files with 28 additions and 34 deletions

View File

@@ -1330,7 +1330,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
@@ -2040,7 +2040,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)