Added max_sample_ arguments (#10551)

* reverted changes of logging and saving metrics

* added max_sample arguments

* fixed code

* white space diff

* reformetting code

* reformatted code
This commit is contained in:
Bhadresh Savani
2021-03-09 00:27:10 +05:30
committed by GitHub
parent 917f104502
commit dfd16af832
14 changed files with 516 additions and 118 deletions

View File

@@ -89,6 +89,27 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
train_file: Optional[str] = field(
default=None, metadata={"help": "A csv or a json file containing the training data."}
)
@@ -353,12 +374,41 @@ def main():
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
if training_args.do_train:
if "train" not in datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
train_dataset = datasets["train"]
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.task_name is not None or data_args.test_file is not None:
if training_args.do_eval:
if "validation" not in datasets and "validation_matched" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in datasets and "test_matched" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
@@ -417,6 +467,10 @@ def main():
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.save_model() # Saves the tokenizer too for easy upload
@@ -425,7 +479,6 @@ def main():
trainer.save_state()
# Evaluation
eval_results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
@@ -437,12 +490,13 @@ def main():
eval_datasets.append(datasets["validation_mismatched"])
for eval_dataset, task in zip(eval_datasets, tasks):
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
metrics = trainer.evaluate(eval_dataset=eval_dataset)
trainer.log_metrics("eval", eval_result)
trainer.save_metrics("eval", eval_result)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
eval_results.update(eval_result)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
@@ -471,7 +525,6 @@ def main():
else:
item = label_list[item]
writer.write(f"{index}\t{item}\n")
return eval_results
def _mp_fn(index):

View File

@@ -247,10 +247,18 @@ def main():
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
trainer.log_metrics("eval", result)
trainer.save_metrics("eval", result)
results.update(result)
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
return results

View File

@@ -294,9 +294,16 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
trainer.log_metrics("eval", result)
trainer.save_metrics("eval", result)
results.update(result)
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
return results

View File

@@ -73,6 +73,27 @@ class DataTrainingArguments:
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
server_ip: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
server_port: Optional[str] = field(default=None, metadata={"help": "For distant debugging."})
@@ -238,12 +259,23 @@ def main():
truncation=True,
)
train_dataset = train_dataset.map(
preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
)
eval_dataset = eval_dataset.map(
preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
)
if training_args.do_train:
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_eval:
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
@@ -288,6 +320,10 @@ def main():
model_path = None
train_result = trainer.train(model_path=model_path)
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
trainer.save_model() # Saves the tokenizer too for easy upload
@@ -296,15 +332,15 @@ def main():
trainer.save_state()
# Evaluation
eval_results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
trainer.log_metrics("eval", eval_result)
trainer.save_metrics("eval", eval_result)
eval_results.update(eval_result)
metrics = trainer.evaluate(eval_dataset=eval_dataset)
return eval_results
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if __name__ == "__main__":