Replace all deprecated jax.ops operations with jnp's at (#16078)

* Replace all deprecated `jax.ops` operations with jnp's `at`

* np to jnp scores

* suggested changes
This commit is contained in:
Sanchit Gandhi
2022-03-16 09:08:55 +00:00
committed by GitHub
parent c2dc89be62
commit ee27b3d7df
13 changed files with 28 additions and 34 deletions

View File

@@ -24,7 +24,6 @@ from transformers.testing_utils import is_pt_flax_cross_test, require_flax
if is_flax_available():
import os
import jax
import jax.numpy as jnp
from jax import jit
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
@@ -219,7 +218,7 @@ class FlaxGenerationTesterMixin:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
attention_mask = attention_mask.at[(0, 0)].set(0)
config.do_sample = False
config.max_length = max_length
@@ -239,7 +238,7 @@ class FlaxGenerationTesterMixin:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
attention_mask = attention_mask.at[(0, 0)].set(0)
config.do_sample = True
config.max_length = max_length
@@ -259,7 +258,7 @@ class FlaxGenerationTesterMixin:
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
# pad attention mask on the left
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
attention_mask = attention_mask.at[(0, 0)].set(0)
config.num_beams = 2
config.max_length = max_length