indents test
This commit is contained in:
committed by
Lysandre Debut
parent
3cdb38a7c0
commit
414e9e7122
@@ -123,7 +123,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
# Load in optimizer and scheduler states
|
# Load in optimizer and scheduler states
|
||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
@@ -744,7 +744,7 @@ def main():
|
|||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
# Make sure only the first process in distributed training will download model & vocab
|
# Make sure only the first process in distributed training will download model & vocab
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||||
|
|||||||
Reference in New Issue
Block a user