From 2e4082364e4bd001f7933d81b3f75548704f79d7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 6 Aug 2021 11:21:37 +0200 Subject: [PATCH] [Flax T5] Speed up t5 training (#13012) * fix_torch_device_generate_test * remove @ * update * up * fix * remove f-stings * correct readme * up Co-authored-by: Patrick von Platen --- examples/flax/language-modeling/README.md | 10 +++++----- examples/flax/language-modeling/run_t5_mlm_flax.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index cd82a6b9bc..9bbcc7e895 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -373,15 +373,15 @@ Next we can run the example script to pretrain the model: --weight_decay="0.001" \ --warmup_steps="2000" \ --overwrite_output_dir \ - --logging_steps="100" \ - --save_steps="1000" \ - --eval_steps="1000" \ + --logging_steps="500" \ + --save_steps="10000" \ + --eval_steps="2500" \ --push_to_hub ``` Training should converge at a loss and accuracy -of 2.2 and 58.0 respectively after 2 epochs on a single TPUv3-8. -This should take around 24 hours. +of 2.36 and 57.0 respectively after 3 epochs on a single TPUv3-8. +This should take around 4.5 hours. Training statistics can be accessed on directly on the 🤗 [hub](https://huggingface.co/patrickvonplaten/t5-base-norwegian/tensorboard) ## Runtime evaluation diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index e50381e4d9..14ef8eb524 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -353,7 +353,8 @@ class FlaxDataCollatorForT5MLM: np.random.shuffle(mask_indices) first_in_segment = np.pad(mask_indices, [[1, 0]]) segment_id = np.cumsum(first_in_segment) - segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id)) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) return segment_length noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans) @@ -720,7 +721,7 @@ if __name__ == "__main__": state = jax_utils.replicate(state) train_time = 0 - epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) + epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time()