New features for CodeParrot training script (#16851)
* add tflops logging and fix grad accumulation
* add accelerate tracking and checkpointing
* scale loss of last batch correctly
* fix typo
* compress loss computation
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* add resume from checkpoint argument
* add load_state accelerate from checkpoint, register lr scheduler and add tflops function
* reformat code
* reformat code
* add condition on path for resume checkpoint
* combine if conditions
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* add source for tflops formula
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
This commit is contained in:
@@ -49,6 +49,10 @@ class TrainingArguments:
|
||||
default=1024,
|
||||
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
|
||||
)
|
||||
resume_from_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "States path if the training should continue from a checkpoint folder."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
@@ -7,11 +9,9 @@ import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import IterableDataset
|
||||
from torch.utils.data.dataloader import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import transformers
|
||||
import wandb
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from arguments import TrainingArguments
|
||||
from huggingface_hub import Repository
|
||||
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||
@@ -39,6 +39,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
||||
self.epoch = 0
|
||||
self.infinite = infinite
|
||||
self.current_size = 0
|
||||
|
||||
def __iter__(self):
|
||||
iterator = iter(self.dataset)
|
||||
@@ -66,6 +67,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
for i in range(0, len(all_token_ids), self.seq_length):
|
||||
input_ids = all_token_ids[i : i + self.seq_length]
|
||||
if len(input_ids) == self.seq_length:
|
||||
self.current_size += 1
|
||||
yield torch.tensor(input_ids)
|
||||
|
||||
|
||||
@@ -82,20 +84,17 @@ def setup_logging(args):
|
||||
handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
|
||||
)
|
||||
if accelerator.is_main_process: # we only want to setup logging once
|
||||
wandb.init(project=project_name, config=args)
|
||||
run_name = wandb.run.name
|
||||
tb_writer = SummaryWriter()
|
||||
tb_writer.add_hparams(vars(args), {"0": 0})
|
||||
accelerator.init_trackers(project_name, vars(args))
|
||||
run_name = accelerator.trackers[0].run.name
|
||||
logger.setLevel(logging.INFO)
|
||||
datasets.utils.logging.set_verbosity_info()
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
tb_writer = None
|
||||
run_name = ""
|
||||
logger.setLevel(logging.ERROR)
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
return logger, tb_writer, run_name
|
||||
return logger, run_name
|
||||
|
||||
|
||||
def create_dataloaders(args):
|
||||
@@ -126,8 +125,22 @@ def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
|
||||
def log_metrics(step, metrics):
|
||||
logger.info(f"Step {step}: {metrics}")
|
||||
if accelerator.is_main_process:
|
||||
wandb.log(metrics)
|
||||
[tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
|
||||
accelerator.log(metrics, step)
|
||||
|
||||
|
||||
def compute_tflops(elapsed_time, accelerator, args):
    """Estimate the TFLOPs achieved per process for one training iteration.

    Implements the FLOPs-per-iteration formula from Equation 3 in Section 5.1
    of https://arxiv.org/pdf/2104.04473.pdf and divides it by the elapsed time
    and the number of processes.

    Note: besides its parameters, this function reads the module-level
    ``model`` and ``tokenizer`` objects.

    Args:
        elapsed_time: Duration the FLOP count is divided by (the caller passes
            the wall-clock seconds measured with ``time.time()``).
        accelerator: Object providing ``unwrap_model`` and
            ``state.num_processes``.
        args: Object providing ``gradient_checkpointing``,
            ``train_batch_size``, ``gradient_accumulation_steps`` and
            ``seq_length``.

    Returns:
        The estimated TFLOPs per process.
    """
    # TFLOPs formula (from Equation 3 in Section 5.1 of https://arxiv.org/pdf/2104.04473.pdf).
    model_config = accelerator.unwrap_model(model).config
    num_layers = model_config.n_layer
    hidden_size = model_config.n_embd
    seq_len = args.seq_length

    # The multiplier is 4 when gradient checkpointing is on and 3 otherwise.
    if args.gradient_checkpointing:
        pass_multiplier = 4
    else:
        pass_multiplier = 3

    # Global batch: per-device batch x number of processes x accumulation steps.
    global_batch = args.train_batch_size * accelerator.state.num_processes * args.gradient_accumulation_steps

    base_flops = 24 * pass_multiplier * global_batch * seq_len * num_layers * (hidden_size**2)
    correction = 1.0 + (seq_len / (6.0 * hidden_size)) + (tokenizer.vocab_size / (16.0 * num_layers * hidden_size))
    iteration_flops = base_flops * correction

    # Normalise by time and process count, then scale FLOPs -> TFLOPs.
    return iteration_flops / (elapsed_time * accelerator.state.num_processes * (10**12))
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
@@ -140,7 +153,8 @@ def evaluate(args):
|
||||
losses.append(accelerator.gather(loss))
|
||||
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
||||
break
|
||||
loss = torch.mean(torch.cat(losses))
|
||||
losses = torch.cat(losses)
|
||||
loss = losses[: eval_dataloader.dataset.current_size].mean()
|
||||
try:
|
||||
perplexity = torch.exp(loss)
|
||||
except OverflowError:
|
||||
@@ -149,7 +163,7 @@ def evaluate(args):
|
||||
|
||||
|
||||
# Accelerator
|
||||
accelerator = Accelerator()
|
||||
accelerator = Accelerator(log_with=["wandb", "tensorboard"])
|
||||
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
||||
|
||||
# Settings
|
||||
@@ -165,7 +179,7 @@ if accelerator.is_main_process:
|
||||
hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
|
||||
|
||||
# Logging
|
||||
logger, tb_writer, run_name = setup_logging(args)
|
||||
logger, run_name = setup_logging(args)
|
||||
logger.info(accelerator.state)
|
||||
|
||||
# Checkout new branch on repo
|
||||
@@ -189,6 +203,7 @@ lr_scheduler = get_scheduler(
|
||||
num_warmup_steps=args.num_warmup_steps,
|
||||
num_training_steps=args.max_train_steps,
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
|
||||
def get_lr():
|
||||
@@ -200,29 +215,58 @@ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, eval_dataloader
|
||||
)
|
||||
|
||||
# load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
accelerator.load_state(args.resume_from_checkpoint)
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = [f.name for f in os.scandir(args.save_dir) if f.is_dir() and "step" in str(f)]
|
||||
dirs.sort(key=os.path.getctime)
|
||||
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||
# Extract the step of the checkpoint to continue from there
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
resume_step = int(training_difference.replace("step_", ""))
|
||||
|
||||
# Train model
|
||||
model.train()
|
||||
completed_steps = 0
|
||||
t_start = time.time()
|
||||
for step, batch in enumerate(train_dataloader, start=1):
|
||||
if args.resume_from_checkpoint and step < resume_step:
|
||||
continue # we need to skip steps until we reach the resumed step
|
||||
loss = model(batch, labels=batch, use_cache=False).loss
|
||||
log_metrics(
|
||||
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
|
||||
)
|
||||
loss = loss / args.gradient_accumulation_steps
|
||||
accelerator.backward(loss)
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
if step % args.gradient_accumulation_steps != 0:
|
||||
# Prevent backward from doing gradient all_reduce in every step
|
||||
if accelerator.distributed_type == DistributedType.MULTI_GPU:
|
||||
with model.no_sync():
|
||||
accelerator.backward(loss)
|
||||
else:
|
||||
accelerator.backward(loss)
|
||||
else:
|
||||
accelerator.backward(loss)
|
||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
completed_steps += 1
|
||||
elapsed_time = time.time() - t_start
|
||||
tflops = compute_tflops(elapsed_time, accelerator, args)
|
||||
log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
|
||||
t_start = time.time()
|
||||
if step % args.save_checkpoint_steps == 0:
|
||||
logger.info("Evaluating and saving model checkpoint")
|
||||
eval_loss, perplexity = evaluate(args)
|
||||
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
||||
save_dir = os.path.join(args.save_dir, f"step_{step}")
|
||||
accelerator.save_state(save_dir)
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.push_to_hub(commit_message=f"step {step}")
|
||||
model.train()
|
||||
@@ -236,5 +280,7 @@ log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||
accelerator.wait_for_everyone()
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
||||
save_dir = os.path.join(args.save_dir, f"step_{step}")
|
||||
accelerator.save_state(save_dir)
|
||||
if accelerator.is_main_process:
|
||||
hf_repo.push_to_hub(commit_message="final model")
|
||||
|
||||
Reference in New Issue
Block a user