Fix t5 shard on TPU Pods (#16527)

* Fix t5 shard on TPU Pods

The current script doesn't work properly on a TPU pod because the global batch is not divided correctly per host.
This pull request fixes this issue by dividing the global batch to each host before it is shared on each host.

* fix style

Co-authored-by: ahmed-elnaggar <ahmed.elnaggar@allianz.com>
This commit is contained in:
Ahmed Elnaggar
2022-04-11 16:45:20 +02:00
committed by GitHub
parent 2831826bc6
commit 5e68675755

View File

@@ -746,6 +746,9 @@ def main():
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
num_of_hosts = jax.process_count()
current_host_idx = jax.process_index()
# Create learning rate schedule
warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
@@ -861,8 +864,13 @@ def main():
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
local_host_model_inputs = {
key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
for key, value in model_inputs.data.items()
}
# Model forward
model_inputs = shard(model_inputs.data)
model_inputs = shard(local_host_model_inputs)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
train_metrics.append(train_metric)