indents
This commit is contained in:
committed by
Lysandre Debut
parent
ebd45980a0
commit
3cdb38a7c0
@@ -124,7 +124,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
||||||
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -157,7 +157,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
logger.info(" Total optimization steps = %d", t_total)
|
logger.info(" Total optimization steps = %d", t_total)
|
||||||
|
|
||||||
global_step = 1
|
global_step = 1
|
||||||
epochs_trained = 0
|
epochs_trained = 0
|
||||||
steps_trained_in_current_epoch = 0
|
steps_trained_in_current_epoch = 0
|
||||||
# Check if continuing training from a checkpoint
|
# Check if continuing training from a checkpoint
|
||||||
if os.path.exists(args.model_name_or_path):
|
if os.path.exists(args.model_name_or_path):
|
||||||
@@ -178,10 +178,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
tr_loss, logging_loss = 0.0, 0.0
|
tr_loss, logging_loss = 0.0, 0.0
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
train_iterator = trange(
|
train_iterator = trange(
|
||||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||||
)
|
)
|
||||||
# Added here for reproductibility
|
# Added here for reproductibility
|
||||||
set_seed(args)
|
set_seed(args)
|
||||||
|
|
||||||
for _ in train_iterator:
|
for _ in train_iterator:
|
||||||
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
||||||
@@ -207,7 +207,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
|
||||||
if args.model_type in ["xlnet", "xlm"]:
|
if args.model_type in ["xlnet", "xlm"]:
|
||||||
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
|
||||||
if args.version_2_with_negative:
|
if args.version_2_with_negative:
|
||||||
inputs.update({"is_impossible": batch[7]})
|
inputs.update({"is_impossible": batch[7]})
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
loss, start_logits_stu, end_logits_stu = outputs
|
loss, start_logits_stu, end_logits_stu = outputs
|
||||||
@@ -261,7 +261,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
|
||||||
# Log metrics
|
# Log metrics
|
||||||
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||||
# Only evaluate when single GPU otherwise metrics may not average well
|
# Only evaluate when single GPU otherwise metrics may not average well
|
||||||
if args.local_rank == -1 and args.evaluate_during_training:
|
if args.local_rank == -1 and args.evaluate_during_training:
|
||||||
@@ -281,7 +281,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
|||||||
model.module if hasattr(model, "module") else model
|
model.module if hasattr(model, "module") else model
|
||||||
) # Take care of distributed/parallel training
|
) # Take care of distributed/parallel training
|
||||||
model_to_save.save_pretrained(output_dir)
|
model_to_save.save_pretrained(output_dir)
|
||||||
tokenizer.save_pretrained(output_dir)
|
tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
@@ -325,7 +325,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||||
|
|
||||||
all_results = []
|
all_results = []
|
||||||
start_time = timeit.default_timer()
|
start_time = timeit.default_timer()
|
||||||
|
|
||||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -425,7 +425,7 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
|
||||||
if args.local_rank not in [-1, 0] and not evaluate:
|
if args.local_rank not in [-1, 0] and not evaluate:
|
||||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
# Load data features from cache or dataset file
|
# Load data features from cache or dataset file
|
||||||
input_file = args.predict_file if evaluate else args.train_file
|
input_file = args.predict_file if evaluate else args.train_file
|
||||||
@@ -468,7 +468,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=not evaluate,
|
is_training=not evaluate,
|
||||||
return_dataset="pt",
|
return_dataset="pt",
|
||||||
threads=args.threads,
|
threads=args.threads,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
@@ -476,7 +476,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
torch.save({"features": features, "dataset": dataset, "examples": examples}, cached_features_file)
|
||||||
|
|
||||||
if args.local_rank == 0 and not evaluate:
|
if args.local_rank == 0 and not evaluate:
|
||||||
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
if output_examples:
|
if output_examples:
|
||||||
@@ -541,11 +541,11 @@ def main():
|
|||||||
help="The input data dir. Should contain the .json files for the task."
|
help="The input data dir. Should contain the .json files for the task."
|
||||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_file",
|
"--train_file",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="The input training file. If a data dir is specified, will look for the file there"
|
help="The input training file. If a data dir is specified, will look for the file there"
|
||||||
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
+ "If no data dir or train/predict files are specified, will run with tensorflow_datasets.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -688,7 +688,7 @@ def main():
|
|||||||
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
|
||||||
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
|
||||||
|
|
||||||
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -743,7 +743,7 @@ def main():
|
|||||||
|
|
||||||
# Load pretrained model and tokenizer
|
# Load pretrained model and tokenizer
|
||||||
if args.local_rank not in [-1, 0]:
|
if args.local_rank not in [-1, 0]:
|
||||||
# Make sure only the first process in distributed training will download model & vocab
|
# Make sure only the first process in distributed training will download model & vocab
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
args.model_type = args.model_type.lower()
|
args.model_type = args.model_type.lower()
|
||||||
@@ -781,7 +781,7 @@ def main():
|
|||||||
teacher = None
|
teacher = None
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
# Make sure only the first process in distributed training will download model & vocab
|
# Make sure only the first process in distributed training will download model & vocab
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user