[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)
* added changes for uniformity
* modified files
* corrected typo
* fixed qa scripts
* fix typos
* fixed predict typo in qa no trainer
* fixed test file
* reverted trainer changes
* reverted trainer changes in custom examples
* updated readme
* added changes in deepspeed test
* added changes for predict and eval
This commit is contained in:
@@ -164,17 +164,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -468,13 +468,13 @@ def main():
|
||||
|
||||
if "validation" in datasets:
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if "test" in datasets:
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -513,15 +513,15 @@ def main():
|
||||
|
||||
# region Prediction
|
||||
if "test" in datasets:
|
||||
logger.info("Doing predictions on test dataset...")
|
||||
logger.info("Doing predictions on Predict dataset...")
|
||||
|
||||
test_dataset = DataSequence(
|
||||
test_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
||||
predict_dataset = DataSequence(
|
||||
predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
||||
)
|
||||
predictions = model.predict(test_dataset)["logits"]
|
||||
predictions = model.predict(predict_dataset)["logits"]
|
||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||
with open(output_test_file, "w") as writer:
|
||||
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
||||
with open(output_predict_file, "w") as writer:
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
if is_regression:
|
||||
@@ -529,7 +529,7 @@ def main():
|
||||
else:
|
||||
item = model.config.id2label[item]
|
||||
writer.write(f"{index}\t{item}\n")
|
||||
logger.info(f"Wrote predictions to {output_test_file}!")
|
||||
logger.info(f"Wrote predictions to {output_predict_file}!")
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user