update run_glue
This commit is contained in:
@@ -69,7 +69,11 @@ def train(args, train_dataset, model):
|
|||||||
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
||||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|
||||||
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
if args.max_steps > 0:
|
||||||
|
num_train_optimization_steps = args.max_steps
|
||||||
|
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
||||||
|
else:
|
||||||
|
num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
||||||
|
|
||||||
# Prepare optimizer
|
# Prepare optimizer
|
||||||
param_optimizer = list(model.named_parameters())
|
param_optimizer = list(model.named_parameters())
|
||||||
@@ -91,10 +95,8 @@ def train(args, train_dataset, model):
|
|||||||
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
|
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion, t_total=num_train_optimization_steps)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
optimizer = BertAdam(optimizer_grouped_parameters,
|
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion,
|
||||||
lr=args.learning_rate,
|
t_total=num_train_optimization_steps)
|
||||||
warmup=args.warmup_proportion,
|
|
||||||
t_total=num_train_optimization_steps)
|
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
@@ -113,7 +115,7 @@ def train(args, train_dataset, model):
|
|||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'attention_mask': batch[1],
|
'attention_mask': batch[1],
|
||||||
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM doesn't use segment_ids
|
||||||
'labels': batch[3]}
|
'labels': batch[3]}
|
||||||
ouputs = model(**inputs)
|
ouputs = model(**inputs)
|
||||||
loss = ouputs[0]
|
loss = ouputs[0]
|
||||||
@@ -140,14 +142,16 @@ def train(args, train_dataset, model):
|
|||||||
if not args.fp16:
|
if not args.fp16:
|
||||||
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
|
||||||
tb_writer.add_scalar('loss', loss.item(), global_step)
|
tb_writer.add_scalar('loss', loss.item(), global_step)
|
||||||
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
|
break
|
||||||
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
return global_step, tr_loss / global_step
|
return global_step, tr_loss / global_step
|
||||||
|
|
||||||
|
|
||||||
def evalutate(args, eval_task, eval_output_dir, dataset, model):
|
def evalutate(args, eval_task, eval_output_dir, dataset, model):
|
||||||
""" Evaluate the model """
|
""" Evaluate the model """
|
||||||
if os.path.exists(eval_output_dir) and os.listdir(eval_output_dir) and args.do_train and not args.overwrite_output_dir:
|
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(eval_output_dir))
|
|
||||||
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(eval_output_dir)
|
os.makedirs(eval_output_dir)
|
||||||
|
|
||||||
@@ -166,13 +170,13 @@ def evalutate(args, eval_task, eval_output_dir, dataset, model):
|
|||||||
out_label_ids = None
|
out_label_ids = None
|
||||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
input_ids, input_mask, segment_ids, label_ids = batch
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(input_ids,
|
inputs = {'input_ids': batch[0],
|
||||||
token_type_ids=segment_ids,
|
'attention_mask': batch[1],
|
||||||
attention_mask=input_mask,
|
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM doesn't use segment_ids
|
||||||
labels=label_ids)
|
'labels': batch[3]}
|
||||||
|
outputs = model(**inputs)
|
||||||
tmp_eval_loss, logits = outputs[:2]
|
tmp_eval_loss, logits = outputs[:2]
|
||||||
|
|
||||||
eval_loss += tmp_eval_loss.mean().item()
|
eval_loss += tmp_eval_loss.mean().item()
|
||||||
@@ -276,6 +280,8 @@ def main():
|
|||||||
help="The initial learning rate for Adam.")
|
help="The initial learning rate for Adam.")
|
||||||
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
parser.add_argument("--num_train_epochs", default=3.0, type=float,
|
||||||
help="Total number of training epochs to perform.")
|
help="Total number of training epochs to perform.")
|
||||||
|
parser.add_argument("--max_steps", default=-1, type=int,
|
||||||
|
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
|
||||||
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
parser.add_argument("--warmup_proportion", default=0.1, type=float,
|
||||||
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
|
||||||
parser.add_argument("--no_cuda", action='store_true',
|
parser.add_argument("--no_cuda", action='store_true',
|
||||||
@@ -299,6 +305,9 @@ def main():
|
|||||||
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
||||||
|
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
||||||
|
|
||||||
# Setup distant debugging if needed
|
# Setup distant debugging if needed
|
||||||
if args.server_ip and args.server_port:
|
if args.server_ip and args.server_port:
|
||||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||||
@@ -320,8 +329,8 @@ def main():
|
|||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
|
||||||
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
|
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
|
||||||
device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
|
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
|
||||||
|
|
||||||
# Setup seeds
|
# Setup seeds
|
||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
@@ -375,8 +384,6 @@ def main():
|
|||||||
# Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
|
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
|
|
||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user