diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index 44b802e1c1..4f58807572 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -123,7 +123,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): # Load in optimizer and scheduler states optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) - + if args.fp16: try: from apex import amp @@ -744,7 +744,7 @@ def main(): # Load pretrained model and tokenizer if args.local_rank not in [-1, 0]: # Make sure only the first process in distributed training will download model & vocab - torch.distributed.barrier() + torch.distributed.barrier() args.model_type = args.model_type.lower() config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]