From 5e68675755e3d45aade61950764e361ae82c7022 Mon Sep 17 00:00:00 2001 From: Ahmed Elnaggar Date: Mon, 11 Apr 2022 16:45:20 +0200 Subject: [PATCH] Fix t5 shard on TPU Pods (#16527) * Fix t5 shard on TPU Pods The current script doesn't work properly on a TPU pod because the global batch is not divided correctly per host. This pull request fixes this issue by dividing the global batch to each host before it is shared on each host. * fix style Co-authored-by: ahmed-elnaggar --- examples/flax/language-modeling/run_t5_mlm_flax.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 5b1067cd99..368ecf0e61 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -746,6 +746,9 @@ def main(): num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs + num_of_hosts = jax.process_count() + current_host_idx = jax.process_index() + # Create learning rate schedule warmup_fn = optax.linear_schedule( init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps @@ -861,8 +864,13 @@ def main(): samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx] model_inputs = data_collator(samples) + local_host_model_inputs = { + key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx] + for key, value in model_inputs.data.items() + } + # Model forward - model_inputs = shard(model_inputs.data) + model_inputs = shard(local_host_model_inputs) state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs) train_metrics.append(train_metric)