[style] consistent nn. and nn.functional: part 4 examples (#12156)
* consistent nn. and nn.functional: p4 examples * restore
This commit is contained in:
@@ -9,6 +9,7 @@ import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
@@ -135,11 +136,11 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
|
||||
|
||||
# multi-gpu training (should be after apex fp16 initialization)
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model = nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
)
|
||||
|
||||
@@ -190,9 +191,9 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
|
||||
tr_loss += loss.item()
|
||||
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step() # Update learning rate schedule
|
||||
@@ -255,7 +256,7 @@ def evaluate(args, model, tokenizer, prefix="", output_layer=-1, eval_highway=Fa
|
||||
|
||||
# multi-gpu eval
|
||||
if args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
model = nn.DataParallel(model)
|
||||
|
||||
# Eval!
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
|
||||
Reference in New Issue
Block a user