[Examples] Added predict stage and Updated Example Template (#10868)
* added predict stage * added test keyword in exception message * removed example specific saving predictions * fixed f-string error * removed extra line Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -207,14 +207,22 @@ def main():
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
# Downloading and loading xnli dataset from the hub.
|
||||
if model_args.train_language is None:
|
||||
train_dataset = load_dataset("xnli", model_args.language, split="train")
|
||||
else:
|
||||
train_dataset = load_dataset("xnli", model_args.train_language, split="train")
|
||||
if training_args.do_train:
|
||||
if model_args.train_language is None:
|
||||
train_dataset = load_dataset("xnli", model_args.language, split="train")
|
||||
else:
|
||||
train_dataset = load_dataset("xnli", model_args.train_language, split="train")
|
||||
label_list = train_dataset.features["label"].names
|
||||
|
||||
if training_args.do_eval:
|
||||
eval_dataset = load_dataset("xnli", model_args.language, split="validation")
|
||||
label_list = eval_dataset.features["label"].names
|
||||
|
||||
if training_args.do_predict:
|
||||
test_dataset = load_dataset("xnli", model_args.language, split="test")
|
||||
label_list = test_dataset.features["label"].names
|
||||
|
||||
eval_dataset = load_dataset("xnli", model_args.language, split="validation")
|
||||
# Labels
|
||||
label_list = train_dataset.features["label"].names
|
||||
num_labels = len(label_list)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
@@ -271,6 +279,9 @@ def main():
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
if training_args.do_eval:
|
||||
if data_args.max_val_samples is not None:
|
||||
@@ -281,9 +292,14 @@ def main():
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
if training_args.do_predict:
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Get the metric function
|
||||
metric = load_metric("xnli")
|
||||
@@ -307,7 +323,7 @@ def main():
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
compute_metrics=compute_metrics,
|
||||
tokenizer=tokenizer,
|
||||
@@ -346,6 +362,26 @@ def main():
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
predictions, labels, metrics = trainer.predict(test_dataset)
|
||||
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_file, "w") as writer:
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
item = label_list[item]
|
||||
writer.write(f"{index}\t{item}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user