[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*
* fix double tree_util
This commit is contained in:
@@ -481,7 +481,7 @@ def main():
|
||||
param_spec = set_partitions(unfreeze(model.params))
|
||||
|
||||
# Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
|
||||
params_shapes = jax.tree_map(lambda x: x.shape, model.params)
|
||||
params_shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params)
|
||||
state_shapes = jax.eval_shape(get_initial_state, params_shapes)
|
||||
|
||||
# get PartitionSpec for opt_state, this is very specific to adamw
|
||||
@@ -492,7 +492,7 @@ def main():
|
||||
return param_spec
|
||||
return None
|
||||
|
||||
opt_state_spec, param_spec = jax.tree_map(
|
||||
opt_state_spec, param_spec = jax.tree_util.tree_map(
|
||||
get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
|
||||
)
|
||||
|
||||
@@ -506,7 +506,7 @@ def main():
|
||||
|
||||
# hack: move the initial params to CPU to free up device memory
|
||||
# TODO: allow loading weights on CPU in pre-trained model
|
||||
model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
|
||||
model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
|
||||
|
||||
# mesh definition
|
||||
mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())
|
||||
@@ -636,7 +636,7 @@ def main():
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = stack_forest(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
try:
|
||||
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
||||
|
||||
Reference in New Issue
Block a user