[Examples] Added predict stage and Updated Example Template (#10868)

* added predict stage

* added test keyword in exception message

* removed example specific saving predictions

* fixed f-string error

* removed extra line

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
Bhadresh Savani
2021-03-23 23:07:59 +05:30
committed by GitHub
parent fb2b89840b
commit 7ef40120a0
2 changed files with 103 additions and 14 deletions

View File

@@ -207,14 +207,22 @@ def main():
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
# Downloading and loading xnli dataset from the hub.
if model_args.train_language is None:
train_dataset = load_dataset("xnli", model_args.language, split="train")
else:
train_dataset = load_dataset("xnli", model_args.train_language, split="train")
if training_args.do_train:
if model_args.train_language is None:
train_dataset = load_dataset("xnli", model_args.language, split="train")
else:
train_dataset = load_dataset("xnli", model_args.train_language, split="train")
label_list = train_dataset.features["label"].names
if training_args.do_eval:
eval_dataset = load_dataset("xnli", model_args.language, split="validation")
label_list = eval_dataset.features["label"].names
if training_args.do_predict:
test_dataset = load_dataset("xnli", model_args.language, split="test")
label_list = test_dataset.features["label"].names
eval_dataset = load_dataset("xnli", model_args.language, split="validation")
# Labels
label_list = train_dataset.features["label"].names
num_labels = len(label_list)
# Load pretrained model and tokenizer
@@ -271,6 +279,9 @@ def main():
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
if training_args.do_eval:
if data_args.max_val_samples is not None:
@@ -281,9 +292,14 @@ def main():
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
if training_args.do_predict:
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Get the metric function
metric = load_metric("xnli")
@@ -307,7 +323,7 @@ def main():
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
@@ -346,6 +362,26 @@ def main():
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
predictions = np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
item = label_list[item]
writer.write(f"{index}\t{item}\n")
if __name__ == "__main__":
main()