Fix t5 shard on TPU Pods (#16527)
* Fix t5 shard on TPU Pods The current script doesn't work properly on a TPU pod because the global batch is not divided correctly per host. This pull request fixes this issue by dividing the global batch to each host before it is shared on each host. * fix style Co-authored-by: ahmed-elnaggar <ahmed.elnaggar@allianz.com>
This commit is contained in:
@@ -746,6 +746,9 @@ def main():
|
|||||||
|
|
||||||
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
||||||
|
|
||||||
|
num_of_hosts = jax.process_count()
|
||||||
|
current_host_idx = jax.process_index()
|
||||||
|
|
||||||
# Create learning rate schedule
|
# Create learning rate schedule
|
||||||
warmup_fn = optax.linear_schedule(
|
warmup_fn = optax.linear_schedule(
|
||||||
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
||||||
@@ -861,8 +864,13 @@ def main():
|
|||||||
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||||
model_inputs = data_collator(samples)
|
model_inputs = data_collator(samples)
|
||||||
|
|
||||||
|
local_host_model_inputs = {
|
||||||
|
key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
|
||||||
|
for key, value in model_inputs.data.items()
|
||||||
|
}
|
||||||
|
|
||||||
# Model forward
|
# Model forward
|
||||||
model_inputs = shard(model_inputs.data)
|
model_inputs = shard(local_host_model_inputs)
|
||||||
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||||
train_metrics.append(train_metric)
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user