New features for CodeParrot training script (#16851)
* add tflops logging and fix grad accumulation * add accelerate tracking and checkpointing * scale loss of last batch correctly * fix typo * compress loss computation Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * add resume from checkpoint argument * add load_state accelerate from checkpoint, register lr scheduler and add tflops function * reformat code * reformat code * add condition on path for resume checkpoint * combine if conditions Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> * add source for tflops formula Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
This commit is contained in:
@@ -82,7 +82,7 @@ Now that the dataset, tokenizer, and model are ready we can start training the m
|
|||||||
First you need to configure `accelerate` and login to Weights & Biases:
|
First you need to configure `accelerate` and login to Weights & Biases:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
acclerate config
|
accelerate config
|
||||||
wandb login
|
wandb login
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ class TrainingArguments:
|
|||||||
default=1024,
|
default=1024,
|
||||||
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
|
metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
|
||||||
)
|
)
|
||||||
|
resume_from_checkpoint: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "States path if the training should continue from a checkpoint folder."},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -7,11 +9,9 @@ import torch
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch.utils.data import IterableDataset
|
from torch.utils.data import IterableDataset
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
import wandb
|
from accelerate import Accelerator, DistributedType
|
||||||
from accelerate import Accelerator
|
|
||||||
from arguments import TrainingArguments
|
from arguments import TrainingArguments
|
||||||
from huggingface_hub import Repository
|
from huggingface_hub import Repository
|
||||||
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
from transformers import AdamW, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, get_scheduler, set_seed
|
||||||
@@ -39,6 +39,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
self.input_characters = seq_length * chars_per_token * num_of_sequences
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.infinite = infinite
|
self.infinite = infinite
|
||||||
|
self.current_size = 0
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
iterator = iter(self.dataset)
|
iterator = iter(self.dataset)
|
||||||
@@ -66,6 +67,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
for i in range(0, len(all_token_ids), self.seq_length):
|
for i in range(0, len(all_token_ids), self.seq_length):
|
||||||
input_ids = all_token_ids[i : i + self.seq_length]
|
input_ids = all_token_ids[i : i + self.seq_length]
|
||||||
if len(input_ids) == self.seq_length:
|
if len(input_ids) == self.seq_length:
|
||||||
|
self.current_size += 1
|
||||||
yield torch.tensor(input_ids)
|
yield torch.tensor(input_ids)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,20 +84,17 @@ def setup_logging(args):
|
|||||||
handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
|
handlers=[logging.FileHandler(log_dir / filename), logging.StreamHandler()],
|
||||||
)
|
)
|
||||||
if accelerator.is_main_process: # we only want to setup logging once
|
if accelerator.is_main_process: # we only want to setup logging once
|
||||||
wandb.init(project=project_name, config=args)
|
accelerator.init_trackers(project_name, vars(args))
|
||||||
run_name = wandb.run.name
|
run_name = accelerator.trackers[0].run.name
|
||||||
tb_writer = SummaryWriter()
|
|
||||||
tb_writer.add_hparams(vars(args), {"0": 0})
|
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
datasets.utils.logging.set_verbosity_info()
|
datasets.utils.logging.set_verbosity_info()
|
||||||
transformers.utils.logging.set_verbosity_info()
|
transformers.utils.logging.set_verbosity_info()
|
||||||
else:
|
else:
|
||||||
tb_writer = None
|
|
||||||
run_name = ""
|
run_name = ""
|
||||||
logger.setLevel(logging.ERROR)
|
logger.setLevel(logging.ERROR)
|
||||||
datasets.utils.logging.set_verbosity_error()
|
datasets.utils.logging.set_verbosity_error()
|
||||||
transformers.utils.logging.set_verbosity_error()
|
transformers.utils.logging.set_verbosity_error()
|
||||||
return logger, tb_writer, run_name
|
return logger, run_name
|
||||||
|
|
||||||
|
|
||||||
def create_dataloaders(args):
|
def create_dataloaders(args):
|
||||||
@@ -126,8 +125,22 @@ def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
|
|||||||
def log_metrics(step, metrics):
|
def log_metrics(step, metrics):
|
||||||
logger.info(f"Step {step}: {metrics}")
|
logger.info(f"Step {step}: {metrics}")
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
wandb.log(metrics)
|
accelerator.log(metrics, step)
|
||||||
[tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]
|
|
||||||
|
|
||||||
|
def compute_tflops(elapsed_time, accelerator, args):
    """Estimate the TFLOPs achieved per process for one optimizer iteration.

    Uses the FLOPs-per-iteration formula from Equation 3 in Section 5.1 of
    https://arxiv.org/pdf/2104.04473.pdf.

    Besides its arguments, this function reads the module-level names
    ``model`` (for ``config.n_layer`` / ``config.n_embd``) and ``tokenizer``
    (for ``vocab_size``).

    Args:
        elapsed_time: Duration of the iteration; the result is divided by it.
            Presumably wall-clock seconds -- confirm against the caller.
        accelerator: Object providing ``unwrap_model`` and
            ``state.num_processes``.
        args: Training arguments providing ``gradient_checkpointing``,
            ``train_batch_size``, ``gradient_accumulation_steps`` and
            ``seq_length``.

    Returns:
        FLOPs per iteration divided by
        ``elapsed_time * num_processes * 10**12``.
    """
    cfg = accelerator.unwrap_model(model).config
    world_size = accelerator.state.num_processes
    layers = cfg.n_layer
    hidden = cfg.n_embd

    # Multiplier is 4 with gradient checkpointing and 3 without it
    # (presumably to account for the recomputed forward pass -- see the paper).
    if args.gradient_checkpointing:
        pass_multiplier = 4
    else:
        pass_multiplier = 3

    # Sequences consumed per optimizer step across all processes.
    global_batch = args.train_batch_size * world_size * args.gradient_accumulation_steps

    leading_term = 24 * pass_multiplier * global_batch * args.seq_length * layers * (hidden**2)
    seq_term = args.seq_length / (6.0 * hidden)
    vocab_term = tokenizer.vocab_size / (16.0 * layers * hidden)
    flops_per_iteration = leading_term * (1.0 + seq_term + vocab_term)

    return flops_per_iteration / (elapsed_time * world_size * (10**12))
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args):
|
def evaluate(args):
|
||||||
@@ -140,7 +153,8 @@ def evaluate(args):
|
|||||||
losses.append(accelerator.gather(loss))
|
losses.append(accelerator.gather(loss))
|
||||||
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
|
||||||
break
|
break
|
||||||
loss = torch.mean(torch.cat(losses))
|
losses = torch.cat(losses)
|
||||||
|
loss = losses[: eval_dataloader.dataset.current_size].mean()
|
||||||
try:
|
try:
|
||||||
perplexity = torch.exp(loss)
|
perplexity = torch.exp(loss)
|
||||||
except OverflowError:
|
except OverflowError:
|
||||||
@@ -149,7 +163,7 @@ def evaluate(args):
|
|||||||
|
|
||||||
|
|
||||||
# Accelerator
|
# Accelerator
|
||||||
accelerator = Accelerator()
|
accelerator = Accelerator(log_with=["wandb", "tensorboard"])
|
||||||
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
@@ -165,7 +179,7 @@ if accelerator.is_main_process:
|
|||||||
hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
|
hf_repo = Repository(args.save_dir, clone_from=args.model_ckpt)
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
logger, tb_writer, run_name = setup_logging(args)
|
logger, run_name = setup_logging(args)
|
||||||
logger.info(accelerator.state)
|
logger.info(accelerator.state)
|
||||||
|
|
||||||
# Checkout new branch on repo
|
# Checkout new branch on repo
|
||||||
@@ -189,6 +203,7 @@ lr_scheduler = get_scheduler(
|
|||||||
num_warmup_steps=args.num_warmup_steps,
|
num_warmup_steps=args.num_warmup_steps,
|
||||||
num_training_steps=args.max_train_steps,
|
num_training_steps=args.max_train_steps,
|
||||||
)
|
)
|
||||||
|
accelerator.register_for_checkpointing(lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
def get_lr():
|
def get_lr():
|
||||||
@@ -200,29 +215,58 @@ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
|
|||||||
model, optimizer, train_dataloader, eval_dataloader
|
model, optimizer, train_dataloader, eval_dataloader
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Load the weights and states from a previous save.
#
# The former inner guard (`is not None or != ""`) was always true, and it sat
# under this truthiness check, so its "most recent checkpoint" fallback could
# never run. That fallback also sorted bare folder names with
# os.path.getctime, which resolves against the current working directory
# instead of args.save_dir. The dead branch is therefore dropped.
if args.resume_from_checkpoint:
    accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
    accelerator.load_state(args.resume_from_checkpoint)
    # normpath strips a trailing separator; without it basename("ckpt/step_8/")
    # is "" and the step number cannot be parsed.
    path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
    # Extract the step of the checkpoint to continue from there. Checkpoint
    # folders are written elsewhere in this script as "step_<N>".
    training_difference = os.path.splitext(path)[0]
    try:
        resume_step = int(training_difference.replace("step_", ""))
    except ValueError as err:
        raise ValueError(
            f"Cannot infer the resume step from checkpoint folder {path!r}; expected a name like 'step_<N>'."
        ) from err
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
model.train()
|
model.train()
|
||||||
completed_steps = 0
|
completed_steps = 0
|
||||||
|
t_start = time.time()
|
||||||
for step, batch in enumerate(train_dataloader, start=1):
|
for step, batch in enumerate(train_dataloader, start=1):
|
||||||
|
if args.resume_from_checkpoint and step < resume_step:
|
||||||
|
continue # we need to skip steps until we reach the resumed step
|
||||||
loss = model(batch, labels=batch, use_cache=False).loss
|
loss = model(batch, labels=batch, use_cache=False).loss
|
||||||
log_metrics(
|
log_metrics(
|
||||||
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
|
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
|
||||||
)
|
)
|
||||||
loss = loss / args.gradient_accumulation_steps
|
loss = loss / args.gradient_accumulation_steps
|
||||||
|
if step % args.gradient_accumulation_steps != 0:
|
||||||
|
# Prevent backward from doing gradient all_reduce in every step
|
||||||
|
if accelerator.distributed_type == DistributedType.MULTI_GPU:
|
||||||
|
with model.no_sync():
|
||||||
|
accelerator.backward(loss)
|
||||||
|
else:
|
||||||
|
accelerator.backward(loss)
|
||||||
|
else:
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if step % args.gradient_accumulation_steps == 0:
|
|
||||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
completed_steps += 1
|
completed_steps += 1
|
||||||
|
elapsed_time = time.time() - t_start
|
||||||
|
tflops = compute_tflops(elapsed_time, accelerator, args)
|
||||||
|
log_metrics(step, {"steps": completed_steps, "tflops": tflops, "time_per_iteration": elapsed_time})
|
||||||
|
t_start = time.time()
|
||||||
if step % args.save_checkpoint_steps == 0:
|
if step % args.save_checkpoint_steps == 0:
|
||||||
logger.info("Evaluating and saving model checkpoint")
|
logger.info("Evaluating and saving model checkpoint")
|
||||||
eval_loss, perplexity = evaluate(args)
|
eval_loss, perplexity = evaluate(args)
|
||||||
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
save_dir = os.path.join(args.save_dir, f"step_{step}")
|
||||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
accelerator.save_state(save_dir)
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
hf_repo.push_to_hub(commit_message=f"step {step}")
|
hf_repo.push_to_hub(commit_message=f"step {step}")
|
||||||
model.train()
|
model.train()
|
||||||
@@ -236,5 +280,7 @@ log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
|
|||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
unwrapped_model = accelerator.unwrap_model(model)
|
unwrapped_model = accelerator.unwrap_model(model)
|
||||||
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
unwrapped_model.save_pretrained(args.save_dir, save_function=accelerator.save)
|
||||||
|
save_dir = os.path.join(args.save_dir, f"step_{step}")
|
||||||
|
accelerator.save_state(save_dir)
|
||||||
if accelerator.is_main_process:
|
if accelerator.is_main_process:
|
||||||
hf_repo.push_to_hub(commit_message="final model")
|
hf_repo.push_to_hub(commit_message="final model")
|
||||||
|
|||||||
Reference in New Issue
Block a user