Replace all deprecated jax.ops operations with jnp's at (#16078)

* Replace all deprecated `jax.ops` operations with jnp's `at`

* np to jnp scores

* suggested changes
This commit is contained in:
Sanchit Gandhi
2022-03-16 09:08:55 +00:00
committed by GitHub
parent c2dc89be62
commit ee27b3d7df
13 changed files with 28 additions and 34 deletions

View File

@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
emb = jnp.zeros((50264, model.config.hidden_size))
# update the first 50257 weights using pre-trained weights
emb = jax.ops.index_update(emb, jax.ops.index[:50257, :], model.params["transformer"]["wte"]["embedding"])
emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"])
params = model.params
params["transformer"]["wte"]["embedding"] = emb