remove jax.ops.index (#16220)
This commit is contained in:
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
|
||||
|
||||
emb = jnp.zeros((50264, model.config.hidden_size))
|
||||
# update the first 50257 weights using pre-trained weights
|
||||
emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"])
|
||||
emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])
|
||||
params = model.params
|
||||
params["transformer"]["wte"]["embedding"] = emb
|
||||
|
||||
|
||||
Reference in New Issue
Block a user