[Flax] Add T5 pretraining script (#12355)
* fix_torch_device_generate_test
* remove @
* add length computation
* finish masking
* finish
* upload
* fix some bugs
* finish
* fix dependency table
* correct tensorboard
* Apply suggestions from code review
* correct processing
* slight change init
* correct some more mistakes
* apply suggestions
* improve readme
* fix indent
* Apply suggestions from code review

Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>

* correct tokenizer
* finish
* finish
* finish
* finish

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
Co-authored-by: SaulLu <55560583+SaulLu@users.noreply.github.com>
committed by GitHub
parent e277074889
commit 31c3e7e75b
@@ -582,12 +582,12 @@ if __name__ == "__main__":
    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_metrics = []
    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)