[Flax] Correct flax training scripts (#12514)
* fix_torch_device_generate_test
* remove @
* add logging steps
* correct training scripts
* correct readme
* correct
This commit is contained in:
committed by GitHub
parent ea55675024
commit bb4ac2b5a8
@@ -56,22 +56,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
|
||||
# Cache the result
|
||||
has_tensorboard = is_tensorboard_available()
|
||||
if has_tensorboard:
|
||||
try:
|
||||
from flax.metrics.tensorboard import SummaryWriter
|
||||
except ImportError as ie:
|
||||
has_tensorboard = False
|
||||
print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
|
||||
|
||||
else:
|
||||
print(
|
||||
"Unable to display metrics through TensorBoard because the package is not installed: "
|
||||
"Please run pip install tensorboard to enable."
|
||||
)
|
||||
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
@@ -269,7 +253,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
|
||||
return batch_idx
|
||||
|
||||
|
||||
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
||||
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
train_metrics = get_metrics(train_metrics)
|
||||
@@ -278,6 +262,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
||||
for i, val in enumerate(vals):
|
||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||
|
||||
|
||||
def write_eval_metric(summary_writer, eval_metrics, step):
|
||||
for metric_name, value in eval_metrics.items():
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
@@ -315,10 +301,6 @@ if __name__ == "__main__":
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
|
||||
# Set the verbosity to info of the Transformers logger (on main process only):
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
@@ -471,8 +453,22 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Enable tensorboard only on the master node
|
||||
has_tensorboard = is_tensorboard_available()
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
||||
try:
|
||||
from flax.metrics.tensorboard import SummaryWriter
|
||||
|
||||
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
||||
except ImportError as ie:
|
||||
has_tensorboard = False
|
||||
logger.warning(
|
||||
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Unable to display metrics through TensorBoard because the package is not installed: "
|
||||
"Please run pip install tensorboard to enable."
|
||||
)
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
@@ -601,7 +597,7 @@ if __name__ == "__main__":
|
||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||
|
||||
# Gather the indexes for creating the batch and do a training step
|
||||
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
@@ -610,11 +606,20 @@ if __name__ == "__main__":
|
||||
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_time += time.time() - train_start
|
||||
cur_step = epoch * num_train_samples + step
|
||||
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||
)
|
||||
if cur_step % training_args.logging_steps and cur_step > 0:
|
||||
# Save metrics
|
||||
train_metric = jax_utils.unreplicate(train_metric)
|
||||
train_time += time.time() - train_start
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||
|
||||
epochs.write(
|
||||
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||
)
|
||||
|
||||
train_metrics = []
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
@@ -645,7 +650,7 @@ if __name__ == "__main__":
|
||||
# Save metrics
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
|
||||
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
|
||||
Reference in New Issue
Block a user