[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*

* fix double tree_util
This commit is contained in:
Sanchit Gandhi
2022-09-09 14:18:56 +01:00
committed by GitHub
parent 22f7218560
commit e6f221c8d4
17 changed files with 49 additions and 49 deletions

View File

@@ -104,7 +104,7 @@ class DataCollator:
def __call__(self, batch):
batch = self.collate_fn(batch)
batch = jax.tree_map(shard, batch)
batch = jax.tree_util.tree_map(shard, batch)
return batch
def collate_fn(self, features):

View File

@@ -608,9 +608,9 @@ if __name__ == "__main__":
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
steps.desc = (
@@ -624,7 +624,7 @@ if __name__ == "__main__":
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,

View File

@@ -551,7 +551,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()

View File

@@ -481,7 +481,7 @@ def main():
param_spec = set_partitions(unfreeze(model.params))
# Get the PyTree for opt_state, we don't actually initialize the opt_state yet.
params_shapes = jax.tree_map(lambda x: x.shape, model.params)
params_shapes = jax.tree_util.tree_map(lambda x: x.shape, model.params)
state_shapes = jax.eval_shape(get_initial_state, params_shapes)
# get PartitionSpec for opt_state, this is very specific to adamw
@@ -492,7 +492,7 @@ def main():
return param_spec
return None
opt_state_spec, param_spec = jax.tree_map(
opt_state_spec, param_spec = jax.tree_util.tree_map(
get_opt_spec, state_shapes, is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState))
)
@@ -506,7 +506,7 @@ def main():
# hack: move the inital params to CPU to free up device memory
# TODO: allow loading weights on CPU in pre-trained model
model.params = jax.tree_map(lambda x: np.asarray(x), model.params)
model.params = jax.tree_util.tree_map(lambda x: np.asarray(x), model.params)
# mesh defination
mesh_devices = np.array(jax.devices()).reshape(1, jax.local_device_count())
@@ -636,7 +636,7 @@ def main():
# normalize eval metrics
eval_metrics = stack_forest(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])

View File

@@ -591,7 +591,7 @@ def main():
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Update progress bar
epochs.write(
@@ -606,7 +606,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub)

View File

@@ -674,9 +674,9 @@ if __name__ == "__main__":
eval_metrics.append(metrics)
eval_metrics_np = get_metrics(eval_metrics)
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
eval_metrics_np = jax.tree_util.tree_map(jnp.sum, eval_metrics_np)
eval_normalizer = eval_metrics_np.pop("normalizer")
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
eval_summary = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
# Update progress bar
epochs.desc = (