[run_lm_finetuning] Train from scratch
This commit is contained in:
@@ -28,7 +28,7 @@ import pickle
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -54,6 +54,7 @@ from transformers import (
|
|||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
OpenAIGPTLMHeadModel,
|
OpenAIGPTLMHeadModel,
|
||||||
OpenAIGPTTokenizer,
|
OpenAIGPTTokenizer,
|
||||||
|
PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
RobertaForMaskedLM,
|
RobertaForMaskedLM,
|
||||||
@@ -82,11 +83,11 @@ MODEL_CLASSES = {
|
|||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
def __init__(self, tokenizer, args, file_path="train", block_size=512):
|
def __init__(self, tokenizer: PreTrainedTokenizer, args, file_path="train", block_size=512):
|
||||||
assert os.path.isfile(file_path)
|
assert os.path.isfile(file_path)
|
||||||
directory, filename = os.path.split(file_path)
|
directory, filename = os.path.split(file_path)
|
||||||
cached_features_file = os.path.join(
|
cached_features_file = os.path.join(
|
||||||
directory, args.model_name_or_path + "_cached_lm_" + str(block_size) + "_" + filename
|
directory, args.model_type + "_cached_lm_" + str(block_size) + "_" + filename
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
||||||
@@ -120,13 +121,12 @@ class TextDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False):
|
||||||
dataset = TextDataset(
|
return TextDataset(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
args,
|
args,
|
||||||
file_path=args.eval_data_file if evaluate else args.train_data_file,
|
file_path=args.eval_data_file if evaluate else args.train_data_file,
|
||||||
block_size=args.block_size,
|
block_size=args.block_size,
|
||||||
)
|
)
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(args):
|
def set_seed(args):
|
||||||
@@ -137,18 +137,11 @@ def set_seed(args):
|
|||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
|
||||||
def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
|
||||||
if not args.save_total_limit:
|
|
||||||
return
|
|
||||||
if args.save_total_limit <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check if we should delete older checkpoint(s)
|
|
||||||
glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
|
|
||||||
if len(glob_checkpoints) <= args.save_total_limit:
|
|
||||||
return
|
|
||||||
|
|
||||||
ordering_and_checkpoint_path = []
|
ordering_and_checkpoint_path = []
|
||||||
|
|
||||||
|
glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))
|
||||||
|
|
||||||
for path in glob_checkpoints:
|
for path in glob_checkpoints:
|
||||||
if use_mtime:
|
if use_mtime:
|
||||||
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
||||||
@@ -159,6 +152,20 @@ def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
|
|||||||
|
|
||||||
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
||||||
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
||||||
|
return checkpoints_sorted
|
||||||
|
|
||||||
|
|
||||||
|
def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
|
||||||
|
if not args.save_total_limit:
|
||||||
|
return
|
||||||
|
if args.save_total_limit <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if we should delete older checkpoint(s)
|
||||||
|
checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
|
||||||
|
if len(checkpoints_sorted) <= args.save_total_limit:
|
||||||
|
return
|
||||||
|
|
||||||
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
|
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
|
||||||
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
||||||
for checkpoint in checkpoints_to_be_deleted:
|
for checkpoint in checkpoints_to_be_deleted:
|
||||||
@@ -191,7 +198,7 @@ def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> T
|
|||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
tb_writer = SummaryWriter()
|
tb_writer = SummaryWriter()
|
||||||
@@ -221,7 +228,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if saved optimizer or scheduler states exist
|
# Check if saved optimizer or scheduler states exist
|
||||||
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
if args.model_name_or_path and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
||||||
os.path.join(args.model_name_or_path, "scheduler.pt")
|
os.path.join(args.model_name_or_path, "scheduler.pt")
|
||||||
):
|
):
|
||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
@@ -263,7 +270,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
epochs_trained = 0
|
epochs_trained = 0
|
||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if args.model_name_or_path and os.path.exists(args.model_name_or_path):
|
||||||
try:
|
try:
|
||||||
# set global_step to gobal_step of last saved checkpoint from model path
|
# set global_step to gobal_step of last saved checkpoint from model path
|
||||||
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
|
checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
|
||||||
@@ -342,8 +349,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
checkpoint_prefix = "checkpoint"
|
checkpoint_prefix = "checkpoint"
|
||||||
# Save model checkpoint
|
# Save model checkpoint
|
||||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
|
output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
|
||||||
if not os.path.exists(output_dir):
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
os.makedirs(output_dir)
|
|
||||||
model_to_save = (
|
model_to_save = (
|
||||||
model.module if hasattr(model, "module") else model
|
model.module if hasattr(model, "module") else model
|
||||||
) # Take care of distributed/parallel training
|
) # Take care of distributed/parallel training
|
||||||
@@ -372,14 +378,14 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, model, tokenizer, prefix=""):
|
def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix="") -> Dict:
|
||||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||||
eval_output_dir = args.output_dir
|
eval_output_dir = args.output_dir
|
||||||
|
|
||||||
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
|
eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
|
||||||
|
|
||||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
os.makedirs(eval_output_dir)
|
os.makedirs(eval_output_dir, exist_ok=True)
|
||||||
|
|
||||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||||
# Note that DistributedSampler samples randomly
|
# Note that DistributedSampler samples randomly
|
||||||
@@ -433,11 +439,16 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output_dir",
|
"--output_dir",
|
||||||
default=None,
|
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The output directory where the model predictions and checkpoints will be written.",
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_type", type=str, required=True, help="The model architecture to be trained or fine-tuned.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--should_continue", action="store_true", help="Whether to continue from latest checkpoint in output_dir"
|
||||||
|
)
|
||||||
|
|
||||||
# Other parameters
|
# Other parameters
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -447,12 +458,11 @@ def main():
|
|||||||
help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
|
help="An optional input evaluation data file to evaluate the perplexity on (a text file).",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_name_or_path",
|
"--model_name_or_path",
|
||||||
default="bert-base-cased",
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="The model checkpoint for weights initialization.",
|
help="The model checkpoint for weights initialization. Leave None if you want to train a model from scratch.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -464,19 +474,25 @@ def main():
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config_name",
|
"--config_name",
|
||||||
default="",
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Optional pretrained config name or path if not the same as model_name_or_path",
|
help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer_name",
|
"--tokenizer_name",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokenizer_init_args",
|
||||||
default="",
|
default="",
|
||||||
type=str,
|
type=str,
|
||||||
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path",
|
help="If instantiating a new tokenizer, comma-separated list of input args to feed the constructor.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--cache_dir",
|
"--cache_dir",
|
||||||
default="",
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
|
help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",
|
||||||
)
|
)
|
||||||
@@ -493,9 +509,6 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
|
parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -563,7 +576,7 @@ def main():
|
|||||||
|
|
||||||
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
|
"BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
|
||||||
"flag (masked language modeling)."
|
"flag (masked language modeling)."
|
||||||
)
|
)
|
||||||
if args.eval_data_file is None and args.do_eval:
|
if args.eval_data_file is None and args.do_eval:
|
||||||
@@ -571,6 +584,14 @@ def main():
|
|||||||
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
|
||||||
"or remove the --do_eval argument."
|
"or remove the --do_eval argument."
|
||||||
)
|
)
|
||||||
|
if args.should_continue:
|
||||||
|
sorted_checkpoints = _sorted_checkpoints(args)
|
||||||
|
if len(sorted_checkpoints) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Used --should_continue but no checkpoint was found in --output_dir."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
args.model_name_or_path = sorted_checkpoints[-1]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
os.path.exists(args.output_dir)
|
os.path.exists(args.output_dir)
|
||||||
@@ -627,26 +648,42 @@ def main():
|
|||||||
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||||
|
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
config = config_class.from_pretrained(
|
|
||||||
args.config_name if args.config_name else args.model_name_or_path,
|
if args.config_name:
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
||||||
)
|
elif args.model_name_or_path:
|
||||||
tokenizer = tokenizer_class.from_pretrained(
|
config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
else:
|
||||||
do_lower_case=args.do_lower_case,
|
config = config_class()
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
|
||||||
|
if args.tokenizer_name:
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
||||||
|
elif args.model_name_or_path:
|
||||||
|
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"You are instantiating a new {} tokenizer from scratch. Are you sure this is what you meant to do?"
|
||||||
|
"To specifiy a pretrained tokenizer name, use --tokenizer_name".format(tokenizer_class.__name__)
|
||||||
)
|
)
|
||||||
|
tokenizer = tokenizer_class(*args.tokenizer_init_args.split(","))
|
||||||
|
|
||||||
if args.block_size <= 0:
|
if args.block_size <= 0:
|
||||||
args.block_size = (
|
args.block_size = tokenizer.max_len_single_sentence
|
||||||
tokenizer.max_len_single_sentence
|
# Our input block size will be the max possible for the model
|
||||||
) # Our input block size will be the max possible for the model
|
else:
|
||||||
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
|
args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
|
||||||
|
|
||||||
|
if args.model_name_or_path:
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(
|
||||||
args.model_name_or_path,
|
args.model_name_or_path,
|
||||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||||
config=config,
|
config=config,
|
||||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
cache_dir=args.cache_dir,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.info("Training new model from scratch")
|
||||||
|
model = model_class(config=config)
|
||||||
|
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
@@ -670,8 +707,8 @@ def main():
|
|||||||
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
logger.info("Saving model checkpoint to %s", args.output_dir)
|
logger.info("Saving model checkpoint to %s", args.output_dir)
|
||||||
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
|
||||||
@@ -687,7 +724,7 @@ def main():
|
|||||||
|
|
||||||
# Load a trained model and vocabulary that you have fine-tuned
|
# Load a trained model and vocabulary that you have fine-tuned
|
||||||
model = model_class.from_pretrained(args.output_dir)
|
model = model_class.from_pretrained(args.output_dir)
|
||||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
|
|||||||
Reference in New Issue
Block a user