Replace all deprecated jax.ops operations with jnp's at (#16078)
* Replace all deprecated `jax.ops` operations with jnp's `at` * np to jnp scores * suggested changes
This commit is contained in:
@@ -41,7 +41,7 @@ if is_flax_available():
|
||||
@require_flax
|
||||
class LogitsProcessorTest(unittest.TestCase):
|
||||
def _get_uniform_logits(self, batch_size: int, length: int):
|
||||
scores = np.ones((batch_size, length)) / length
|
||||
scores = jnp.ones((batch_size, length)) / length
|
||||
return scores
|
||||
|
||||
def test_temperature_dist_warper(self):
|
||||
@@ -51,8 +51,8 @@ class LogitsProcessorTest(unittest.TestCase):
|
||||
scores = self._get_uniform_logits(batch_size=2, length=length)
|
||||
|
||||
# tweak scores to not be uniform anymore
|
||||
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
|
||||
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
|
||||
scores = scores.at[1, 5].set((1 / length) + 0.1) # peak, 1st batch
|
||||
scores = scores.at[1, 10].set((1 / length) - 0.4) # valley, 1st batch
|
||||
|
||||
# compute softmax
|
||||
probs = jax.nn.softmax(scores, axis=-1)
|
||||
|
||||
@@ -24,7 +24,6 @@ from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
if is_flax_available():
|
||||
import os
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model
|
||||
@@ -219,7 +218,7 @@ class FlaxGenerationTesterMixin:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||
|
||||
config.do_sample = False
|
||||
config.max_length = max_length
|
||||
@@ -239,7 +238,7 @@ class FlaxGenerationTesterMixin:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||
|
||||
config.do_sample = True
|
||||
config.max_length = max_length
|
||||
@@ -259,7 +258,7 @@ class FlaxGenerationTesterMixin:
|
||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
||||
|
||||
# pad attention mask on the left
|
||||
attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
|
||||
attention_mask = attention_mask.at[(0, 0)].set(0)
|
||||
|
||||
config.num_beams = 2
|
||||
config.max_length = max_length
|
||||
|
||||
Reference in New Issue
Block a user