Replace all deprecated jax.ops operations with jnp's at (#16078)
* Replace all deprecated `jax.ops` operations with jnp's `at` * np to jnp scores * suggested changes
This commit is contained in:
@@ -1330,7 +1330,7 @@ def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_
|
||||
Shift input ids one token to the right.
|
||||
"""
|
||||
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
|
||||
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
|
||||
shifted_input_ids = shifted_input_ids.at[(..., 0)].set(decoder_start_token_id)
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
|
||||
|
||||
@@ -2040,7 +2040,7 @@ class Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel(FlaxPreTrainedMode
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
# make sure initialization pass will work for Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule
|
||||
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
|
||||
input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
decoder_input_ids = input_ids
|
||||
decoder_attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
Reference in New Issue
Block a user