[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*
* Fix double tree_util
This commit is contained in:
@@ -104,7 +104,7 @@ class DataCollator:
|
||||
|
||||
def __call__(self, batch):
|
||||
batch = self.collate_fn(batch)
|
||||
batch = jax.tree_map(shard, batch)
|
||||
batch = jax.tree_util.tree_map(shard, batch)
|
||||
return batch
|
||||
|
||||
def collate_fn(self, features):
|
||||
|
||||
@@ -608,9 +608,9 @@ if __name__ == "__main__":
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
|
||||
eval_normalizer = eval_metrics.pop("normalizer")
|
||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
|
||||
# Update progress bar
|
||||
steps.desc = (
|
||||
@@ -624,7 +624,7 @@ if __name__ == "__main__":
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(
|
||||
training_args.output_dir,
|
||||
params=params,
|
||||
|
||||
@@ -551,7 +551,7 @@ def main():
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
|
||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
# Print metrics and update progress bar
|
||||
eval_step_progress_bar.close()
|
||||
|
||||
@@ -481,7 +481,7 @@ def main():
|
||||
param_spec = set_partitions(unfreeze(model.params))
|
||||
|
||||
# Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
|
||||
params_shapes = jax.tree_map(lambda x: x.shape, model.params)
|
||||
params_shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params)
|
||||
state_shapes = jax.eval_shape(get_initial_state, params_shapes)
|
||||
|
||||
# get PartitionSpec for opt_state, this is very specific to adamw
|
||||
@@ -492,7 +492,7 @@ def main():
|
||||
return param_spec
|
||||
return None
|
||||
|
||||
opt_state_spec, param_spec = jax.tree_map(
|
||||
opt_state_spec, param_spec = jax.tree_util.tree_map(
|
||||
get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
|
||||
)
|
||||
|
||||
@@ -506,7 +506,7 @@ def main():
|
||||
|
||||
# hack: move the initial params to CPU to free up device memory
|
||||
# TODO: allow loading weights on CPU in pre-trained model
|
||||
model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
|
||||
model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
|
||||
|
||||
# mesh definition
|
||||
mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())
|
||||
@@ -636,7 +636,7 @@ def main():
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = stack_forest(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
try:
|
||||
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
|
||||
|
||||
@@ -591,7 +591,7 @@ def main():
|
||||
|
||||
# get eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
# Update progress bar
|
||||
epochs.write(
|
||||
@@ -606,7 +606,7 @@ def main():
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)
|
||||
|
||||
|
||||
|
||||
@@ -674,9 +674,9 @@ if __name__ == "__main__":
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
eval_metrics_np = get_metrics(eval_metrics)
|
||||
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
|
||||
eval_metrics_np = jax.tree_util.tree_map(jnp.sum, eval_metrics_np)
|
||||
eval_normalizer = eval_metrics_np.pop("normalizer")
|
||||
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
|
||||
eval_summary = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
|
||||
|
||||
# Update progress bar
|
||||
epochs.desc = (
|
||||
|
||||
Reference in New Issue
Block a user