Replace all deprecated jax.ops operations with jnp's at (#16078)

* Replace all deprecated `jax.ops` operations with jnp's `at`

* np to jnp scores

* suggested changes
This commit is contained in:
Sanchit Gandhi
2022-03-16 09:08:55 +00:00
committed by GitHub
parent c2dc89be62
commit ee27b3d7df
13 changed files with 28 additions and 34 deletions

View File

@@ -41,7 +41,7 @@ if is_flax_available():
@require_flax
class LogitsProcessorTest(unittest.TestCase):
def _get_uniform_logits(self, batch_size: int, length: int):
scores = np.ones((batch_size, length)) / length
scores = jnp.ones((batch_size, length)) / length
return scores
def test_temperature_dist_warper(self):
@@ -51,8 +51,8 @@ class LogitsProcessorTest(unittest.TestCase):
scores = self._get_uniform_logits(batch_size=2, length=length)
# tweak scores to not be uniform anymore
scores[1, 5] = (1 / length) + 0.1 # peak, 1st batch
scores[1, 10] = (1 / length) - 0.4 # valley, 1st batch
scores = scores.at[1, 5].set((1 / length) + 0.1) # peak, 1st batch
scores = scores.at[1, 10].set((1 / length) - 0.4) # valley, 1st batch
# compute softmax
probs = jax.nn.softmax(scores, axis=-1)