read parameters from CLI, load model & tokenizer
This commit is contained in:
@@ -30,12 +30,15 @@ Gao, Ming Zhou, and Hsiao-Wuen Hon. “Unified Language Model Pre-Training for
|
||||
Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, Bert2Rnd, BertTokenizer
|
||||
|
||||
# Module-level logger named after this module; used by main() to report
# the run parameters and the final training statistics.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -43,25 +46,60 @@ def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
    """Placeholder for building the training dataset.

    Args:
        args: Parsed command-line arguments.
        tokenizer: Tokenizer instance loaded for the model.

    Raises:
        NotImplementedError: Always; the loader has not been written yet.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
    """Fine-tune the pretrained model on the corpus.

    Not implemented yet. The caller in this file unpacks the result as a
    ``(global_step, tr_loss)`` pair, so a future implementation is expected
    to return those two values.

    Args:
        args: Parsed command-line arguments.
        train_dataset: Dataset produced by ``load_and_cache_examples``.
        model: Model to fine-tune.
        tokenizer: Tokenizer matching the model.

    Raises:
        NotImplementedError: Always, until the training loop is written.
    """
    # Planned outline of the implementation, in order:
    #   1. Data sampler
    #   2. Data loader
    #   3. Training
    raise NotImplementedError()
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
    """Placeholder for model evaluation.

    Args:
        args: Parsed command-line arguments.
        model: Model to evaluate.
        tokenizer: Tokenizer matching the model.
        prefix: Optional string label; defaults to the empty string.

    Raises:
        NotImplementedError: Always; evaluation has not been written yet.
    """
    raise NotImplementedError()
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, load the model and tokenizer, and launch training.

    Command-line arguments:
        --train_data_file (required): input training data file (a text file).
        --output_dir (required): directory for predictions and checkpoints.
        --model_name_or_path: checkpoint used for weight initialization
            (default ``"bert-base-cased"``).
        --seed: random seed passed to ``set_seed`` (default ``42``).

    The model is moved to the CPU device. After training, the global step
    and loss returned by ``train`` are logged.
    """
    # NOTE: a stray unconditional `raise NotImplementedError` used to sit at
    # the top of this function and made everything below unreachable.
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--train_data_file",
        default=None,
        type=str,
        required=True,
        help="The input training data file (a text file).",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint for weights initialization.",
    )
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    # Set up training device
    device = torch.device("cpu")
    # set_seed() reads args.n_gpu, but no CLI flag defines it. Training is
    # pinned to the CPU here, so record that no GPU is in use; without this
    # the set_seed() call below fails with AttributeError.
    args.n_gpu = 0

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    model.to(device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
def __main__():
    """Call ``main()`` when this module is the program entry point.

    Kept only for backward compatibility. Nothing in the visible file calls
    this function, so on its own it never starts the program.
    """
    if __name__ == "__main__":
        main()


# The entry-point guard has to live at module level: wrapped inside the
# never-called function above, running the script silently did nothing.
if __name__ == "__main__":
    main()
|
||||
|
||||
Reference in New Issue
Block a user