[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)
* added changes for uniformity * modified files * corrected typo * fixed qa scripts * fix typos * fixed predict typo in qa no trainer * fixed test file * reverted trainer changes * reverted trainer changes in custom exmaples * updated readme * added changes in deepspeed test * added changes for predict and eval
This commit is contained in:
@@ -50,8 +50,8 @@ For example here is how to truncate all three splits to just 50 samples each:
|
||||
```
|
||||
examples/pytorch/token-classification/run_ner.py \
|
||||
--max_train_samples 50 \
|
||||
--max_val_samples 50 \
|
||||
--max_test_samples 50 \
|
||||
--max_eval_samples 50 \
|
||||
--max_predict_samples 50 \
|
||||
[...]
|
||||
```
|
||||
|
||||
|
||||
@@ -126,10 +126,10 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -397,8 +397,8 @@ def main():
|
||||
if "validation" not in tokenized_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = lm_datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
@@ -439,8 +439,8 @@ def main():
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
|
||||
@@ -157,10 +157,10 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -419,8 +419,8 @@ def main():
|
||||
if "validation" not in tokenized_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = tokenized_datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
@@ -468,8 +468,8 @@ def main():
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
|
||||
@@ -154,10 +154,10 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -397,8 +397,8 @@ def main():
|
||||
if "validation" not in tokenized_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = tokenized_datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorForPermutationLanguageModeling(
|
||||
@@ -444,8 +444,8 @@ def main():
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
perplexity = math.exp(metrics["eval_loss"])
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
|
||||
@@ -127,10 +127,10 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -363,8 +363,8 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -422,8 +422,8 @@ def main():
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
@@ -133,17 +133,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -468,9 +468,9 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_examples = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
if data_args.max_eval_samples is not None:
|
||||
# We will select sample from whole data
|
||||
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
||||
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||
# Validation Feature Creation
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
@@ -479,28 +479,28 @@ def main():
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_val_samples is not None:
|
||||
if data_args.max_eval_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_examples = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
predict_examples = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
# We will select sample from whole data
|
||||
test_examples = test_examples.select(range(data_args.max_test_samples))
|
||||
# Test Feature Creation
|
||||
test_dataset = test_examples.map(
|
||||
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
|
||||
# Predict Feature Creation
|
||||
predict_dataset = predict_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_test_samples is not None:
|
||||
if data_args.max_predict_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
@@ -581,8 +581,8 @@ def main():
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
@@ -590,14 +590,16 @@ def main():
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
results = trainer.predict(test_dataset, test_examples)
|
||||
results = trainer.predict(predict_dataset, predict_examples)
|
||||
metrics = results.metrics
|
||||
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -132,17 +132,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -504,9 +504,9 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_examples = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
if data_args.max_eval_samples is not None:
|
||||
# Selecting Eval Samples from Dataset
|
||||
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
||||
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||
# Create Features from Eval Dataset
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
@@ -515,28 +515,28 @@ def main():
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_val_samples is not None:
|
||||
if data_args.max_eval_samples is not None:
|
||||
# Selecting Samples from Dataset again since Feature Creation might increase samples size
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_examples = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
predict_examples = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
# We will select sample from whole data
|
||||
test_examples = test_examples.select(range(data_args.max_test_samples))
|
||||
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
|
||||
# Test Feature Creation
|
||||
test_dataset = test_examples.map(
|
||||
predict_dataset = predict_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
if data_args.max_test_samples is not None:
|
||||
if data_args.max_predict_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
@@ -620,8 +620,8 @@ def main():
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
@@ -629,14 +629,16 @@ def main():
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
results = trainer.predict(test_dataset, test_examples)
|
||||
results = trainer.predict(predict_dataset, predict_examples)
|
||||
metrics = results.metrics
|
||||
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -183,20 +183,20 @@ def parse_args():
|
||||
"value if set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_val_samples",
|
||||
"--max_eval_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_test_samples",
|
||||
"--max_predict_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For debugging purposes or quicker training, truncate the number of test examples to this",
|
||||
help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -481,9 +481,9 @@ def main():
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_examples = raw_datasets["validation"]
|
||||
if args.max_val_samples is not None:
|
||||
if args.max_eval_samples is not None:
|
||||
# We will select sample from whole data
|
||||
eval_examples = eval_examples.select(range(args.max_val_samples))
|
||||
eval_examples = eval_examples.select(range(args.max_eval_samples))
|
||||
# Validation Feature Creation
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
@@ -493,28 +493,28 @@ def main():
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
)
|
||||
|
||||
if args.max_val_samples is not None:
|
||||
if args.max_eval_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
eval_dataset = eval_dataset.select(range(args.max_val_samples))
|
||||
eval_dataset = eval_dataset.select(range(args.max_eval_samples))
|
||||
|
||||
if args.do_predict:
|
||||
if "test" not in raw_datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_examples = raw_datasets["test"]
|
||||
if args.max_test_samples is not None:
|
||||
predict_examples = raw_datasets["test"]
|
||||
if args.max_predict_samples is not None:
|
||||
# We will select sample from whole data
|
||||
test_examples = test_examples.select(range(args.max_test_samples))
|
||||
# Test Feature Creation
|
||||
test_dataset = test_examples.map(
|
||||
predict_examples = predict_examples.select(range(args.max_predict_samples))
|
||||
# Predict Feature Creation
|
||||
predict_dataset = predict_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
)
|
||||
if args.max_test_samples is not None:
|
||||
if args.max_predict_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
test_dataset = test_dataset.select(range(args.max_test_samples))
|
||||
predict_dataset = predict_dataset.select(range(args.max_predict_samples))
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
@@ -539,9 +539,9 @@ def main():
|
||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||
|
||||
if args.do_predict:
|
||||
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
predict_dataloader = DataLoader(
|
||||
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
@@ -737,7 +737,7 @@ def main():
|
||||
all_end_top_log_probs = []
|
||||
all_end_top_index = []
|
||||
all_cls_logits = []
|
||||
for step, batch in enumerate(test_dataloader):
|
||||
for step, batch in enumerate(predict_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
start_top_log_probs = outputs.start_top_log_probs
|
||||
@@ -762,10 +762,10 @@ def main():
|
||||
max_len = max([x.shape[1] for x in all_end_top_log_probs]) # Get the max_length of the tensor
|
||||
|
||||
# concatenate all numpy arrays collected above
|
||||
start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, test_dataset, max_len)
|
||||
start_top_index_concat = create_and_fill_np_array(all_start_top_index, test_dataset, max_len)
|
||||
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, test_dataset, max_len)
|
||||
end_top_index_concat = create_and_fill_np_array(all_end_top_index, test_dataset, max_len)
|
||||
start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, predict_dataset, max_len)
|
||||
start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
|
||||
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
|
||||
end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
|
||||
all_cls_logits = np.concatenate(all_cls_logits, axis=0)
|
||||
|
||||
# delete the list of numpy arrays
|
||||
@@ -774,7 +774,7 @@ def main():
|
||||
del end_top_log_probs
|
||||
del end_top_index
|
||||
|
||||
test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
|
||||
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
|
||||
outputs_numpy = (
|
||||
start_top_log_probs_concat,
|
||||
start_top_index_concat,
|
||||
@@ -783,9 +783,9 @@ def main():
|
||||
cls_logits,
|
||||
)
|
||||
|
||||
prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
|
||||
test_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
logger.info(f"Test metrics: {test_metric}")
|
||||
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
|
||||
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
logger.info(f"Predict metrics: {predict_metric}")
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -205,20 +205,20 @@ def parse_args():
|
||||
"value if set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_val_samples",
|
||||
"--max_eval_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_test_samples",
|
||||
"--max_predict_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For debugging purposes or quicker training, truncate the number of test examples to this",
|
||||
help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
@@ -486,9 +486,9 @@ def main():
|
||||
if "validation" not in raw_datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_examples = raw_datasets["validation"]
|
||||
if args.max_val_samples is not None:
|
||||
if args.max_eval_samples is not None:
|
||||
# We will select sample from whole data
|
||||
eval_examples = eval_examples.select(range(args.max_val_samples))
|
||||
eval_examples = eval_examples.select(range(args.max_eval_samples))
|
||||
# Validation Feature Creation
|
||||
eval_dataset = eval_examples.map(
|
||||
prepare_validation_features,
|
||||
@@ -498,28 +498,28 @@ def main():
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
)
|
||||
|
||||
if args.max_val_samples is not None:
|
||||
if args.max_eval_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
eval_dataset = eval_dataset.select(range(args.max_val_samples))
|
||||
eval_dataset = eval_dataset.select(range(args.max_eval_samples))
|
||||
|
||||
if args.do_predict:
|
||||
if "test" not in raw_datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_examples = raw_datasets["test"]
|
||||
if args.max_test_samples is not None:
|
||||
predict_examples = raw_datasets["test"]
|
||||
if args.max_predict_samples is not None:
|
||||
# We will select sample from whole data
|
||||
test_examples = test_examples.select(range(args.max_test_samples))
|
||||
# Test Feature Creation
|
||||
test_dataset = test_examples.map(
|
||||
predict_examples = predict_examples.select(range(args.max_predict_samples))
|
||||
# Predict Feature Creation
|
||||
predict_dataset = predict_examples.map(
|
||||
prepare_validation_features,
|
||||
batched=True,
|
||||
num_proc=args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
load_from_cache_file=not args.overwrite_cache,
|
||||
)
|
||||
if args.max_test_samples is not None:
|
||||
if args.max_predict_samples is not None:
|
||||
# During Feature creation dataset samples might increase, we will select required samples again
|
||||
test_dataset = test_dataset.select(range(args.max_test_samples))
|
||||
predict_dataset = predict_dataset.select(range(args.max_predict_samples))
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
@@ -544,9 +544,9 @@ def main():
|
||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||
|
||||
if args.do_predict:
|
||||
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||
predict_dataloader = DataLoader(
|
||||
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
@@ -714,7 +714,7 @@ def main():
|
||||
if args.do_predict:
|
||||
all_start_logits = []
|
||||
all_end_logits = []
|
||||
for step, batch in enumerate(test_dataloader):
|
||||
for step, batch in enumerate(predict_dataloader):
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
start_logits = outputs.start_logits
|
||||
@@ -729,19 +729,19 @@ def main():
|
||||
|
||||
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
|
||||
# concatenate the numpy array
|
||||
start_logits_concat = create_and_fill_np_array(all_start_logits, test_dataset, max_len)
|
||||
end_logits_concat = create_and_fill_np_array(all_end_logits, test_dataset, max_len)
|
||||
start_logits_concat = create_and_fill_np_array(all_start_logits, predict_dataset, max_len)
|
||||
end_logits_concat = create_and_fill_np_array(all_end_logits, predict_dataset, max_len)
|
||||
|
||||
# delete the list of numpy arrays
|
||||
del all_start_logits
|
||||
del all_end_logits
|
||||
|
||||
# Now we need to add extra columns which we removed for post processing
|
||||
test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
|
||||
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
|
||||
outputs_numpy = (start_logits_concat, end_logits_concat)
|
||||
prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
|
||||
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
logger.info(f"Test metrics: {eval_metric}")
|
||||
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
|
||||
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||
logger.info(f"Predict metrics: {predict_metric}")
|
||||
|
||||
if args.output_dir is not None:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -66,16 +66,16 @@ class QuestionAnsweringTrainer(Trainer):
|
||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||
return metrics
|
||||
|
||||
def predict(self, test_dataset, test_examples, ignore_keys=None):
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
def predict(self, predict_dataset, predict_examples, ignore_keys=None):
|
||||
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
||||
|
||||
# Temporarily disable metric computation, we will do it in the loop here.
|
||||
compute_metrics = self.compute_metrics
|
||||
self.compute_metrics = None
|
||||
try:
|
||||
output = self.prediction_loop(
|
||||
test_dataloader,
|
||||
description="Evaluation",
|
||||
predict_dataloader,
|
||||
description="Prediction",
|
||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||
# self.args.prediction_loss_only
|
||||
prediction_loss_only=True if compute_metrics is None else None,
|
||||
@@ -87,7 +87,7 @@ class QuestionAnsweringTrainer(Trainer):
|
||||
if self.post_process_function is None or self.compute_metrics is None:
|
||||
return output
|
||||
|
||||
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
|
||||
metrics = self.compute_metrics(eval_preds)
|
||||
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
|
||||
metrics = self.compute_metrics(predictions)
|
||||
|
||||
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
|
||||
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
||||
|
||||
@@ -178,17 +178,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -438,8 +438,8 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -452,10 +452,10 @@ def main():
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
predict_dataset = predict_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
@@ -547,37 +547,39 @@ def main():
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
test_results = trainer.predict(
|
||||
test_dataset,
|
||||
metric_key_prefix="test",
|
||||
predict_results = trainer.predict(
|
||||
predict_dataset,
|
||||
metric_key_prefix="predict",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
)
|
||||
metrics = test_results.metrics
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
predictions = tokenizer.batch_decode(
|
||||
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
test_preds = [pred.strip() for pred in test_preds]
|
||||
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(test_preds))
|
||||
predictions = [pred.strip() for pred in predictions]
|
||||
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||
with open(output_prediction_file, "w") as writer:
|
||||
writer.write("\n".join(predictions))
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -100,17 +100,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -390,15 +390,15 @@ def main():
|
||||
if "validation" not in datasets and "validation_matched" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||
if "test" not in datasets and "test_matched" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
if training_args.do_train:
|
||||
@@ -483,32 +483,34 @@ def main():
|
||||
for eval_dataset, task in zip(eval_datasets, tasks):
|
||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = (
|
||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
tasks = [data_args.task_name]
|
||||
test_datasets = [test_dataset]
|
||||
predict_datasets = [predict_dataset]
|
||||
if data_args.task_name == "mnli":
|
||||
tasks.append("mnli-mm")
|
||||
test_datasets.append(datasets["test_mismatched"])
|
||||
predict_datasets.append(datasets["test_mismatched"])
|
||||
|
||||
for test_dataset, task in zip(test_datasets, tasks):
|
||||
for predict_dataset, task in zip(predict_datasets, tasks):
|
||||
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
||||
test_dataset.remove_columns_("label")
|
||||
predictions = trainer.predict(test_dataset=test_dataset).predictions
|
||||
predict_dataset.remove_columns_("label")
|
||||
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
|
||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
|
||||
output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
|
||||
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_file, "w") as writer:
|
||||
logger.info(f"***** Test results {task} *****")
|
||||
with open(output_predict_file, "w") as writer:
|
||||
logger.info(f"***** Predict results {task} *****")
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
if is_regression:
|
||||
|
||||
@@ -84,17 +84,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -221,8 +221,8 @@ def main():
|
||||
label_list = eval_dataset.features["label"].names
|
||||
|
||||
if training_args.do_predict:
|
||||
test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
|
||||
label_list = test_dataset.features["label"].names
|
||||
predict_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
|
||||
label_list = predict_dataset.features["label"].names
|
||||
|
||||
# Labels
|
||||
num_labels = len(label_list)
|
||||
@@ -286,8 +286,8 @@ def main():
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
if training_args.do_eval:
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -295,9 +295,9 @@ def main():
|
||||
)
|
||||
|
||||
if training_args.do_predict:
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
predict_dataset = predict_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
@@ -360,8 +360,8 @@ def main():
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
@@ -369,18 +369,20 @@ def main():
|
||||
# Prediction
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
predictions, labels, metrics = trainer.predict(test_dataset)
|
||||
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
|
||||
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
||||
output_predict_file = os.path.join(training_args.output_dir, "predictions.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_file, "w") as writer:
|
||||
with open(output_predict_file, "w") as writer:
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
item = label_list[item]
|
||||
|
||||
@@ -128,17 +128,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -363,8 +363,8 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
tokenize_and_align_labels,
|
||||
batched=True,
|
||||
@@ -375,10 +375,10 @@ def main():
|
||||
if training_args.do_predict:
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
predict_dataset = predict_dataset.map(
|
||||
tokenize_and_align_labels,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
@@ -462,8 +462,8 @@ def main():
|
||||
|
||||
metrics = trainer.evaluate()
|
||||
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
@@ -472,7 +472,7 @@ def main():
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
predictions, labels, metrics = trainer.predict(test_dataset)
|
||||
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
|
||||
predictions = np.argmax(predictions, axis=2)
|
||||
|
||||
# Remove ignored index (special tokens)
|
||||
@@ -481,13 +481,13 @@ def main():
|
||||
for prediction, label in zip(predictions, labels)
|
||||
]
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
# Save predictions
|
||||
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
||||
output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
|
||||
if trainer.is_world_process_zero():
|
||||
with open(output_test_predictions_file, "w") as writer:
|
||||
with open(output_predictions_file, "w") as writer:
|
||||
for prediction in true_predictions:
|
||||
writer.write(" ".join(prediction) + "\n")
|
||||
|
||||
|
||||
@@ -167,17 +167,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -432,8 +432,8 @@ def main():
|
||||
if "validation" not in datasets:
|
||||
raise ValueError("--do_eval requires a validation dataset")
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
@@ -446,10 +446,10 @@ def main():
|
||||
max_target_length = data_args.val_max_target_length
|
||||
if "test" not in datasets:
|
||||
raise ValueError("--do_predict requires a test dataset")
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
predict_dataset = predict_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
@@ -539,37 +539,39 @@ def main():
|
||||
metrics = trainer.evaluate(
|
||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||
)
|
||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.do_predict:
|
||||
logger.info("*** Test ***")
|
||||
logger.info("*** Predict ***")
|
||||
|
||||
test_results = trainer.predict(
|
||||
test_dataset,
|
||||
metric_key_prefix="test",
|
||||
predict_results = trainer.predict(
|
||||
predict_dataset,
|
||||
metric_key_prefix="predict",
|
||||
max_length=data_args.val_max_target_length,
|
||||
num_beams=data_args.num_beams,
|
||||
)
|
||||
metrics = test_results.metrics
|
||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
||||
metrics = predict_results.metrics
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||
|
||||
trainer.log_metrics("test", metrics)
|
||||
trainer.save_metrics("test", metrics)
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
if trainer.is_world_process_zero():
|
||||
if training_args.predict_with_generate:
|
||||
test_preds = tokenizer.batch_decode(
|
||||
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
predictions = tokenizer.batch_decode(
|
||||
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
test_preds = [pred.strip() for pred in test_preds]
|
||||
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
||||
with open(output_test_preds_file, "w") as writer:
|
||||
writer.write("\n".join(test_preds))
|
||||
predictions = [pred.strip() for pred in predictions]
|
||||
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||
with open(output_prediction_file, "w") as writer:
|
||||
writer.write("\n".join(predictions))
|
||||
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -164,17 +164,17 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_val_samples: Optional[int] = field(
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
max_test_samples: Optional[int] = field(
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
||||
"help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
@@ -468,13 +468,13 @@ def main():
|
||||
|
||||
if "validation" in datasets:
|
||||
eval_dataset = datasets["validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
if data_args.max_eval_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||
|
||||
if "test" in datasets:
|
||||
test_dataset = datasets["test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
predict_dataset = datasets["test"]
|
||||
if data_args.max_predict_samples is not None:
|
||||
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -513,15 +513,15 @@ def main():
|
||||
|
||||
# region Prediction
|
||||
if "test" in datasets:
|
||||
logger.info("Doing predictions on test dataset...")
|
||||
logger.info("Doing predictions on Predict dataset...")
|
||||
|
||||
test_dataset = DataSequence(
|
||||
test_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
||||
predict_dataset = DataSequence(
|
||||
predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
||||
)
|
||||
predictions = model.predict(test_dataset)["logits"]
|
||||
predictions = model.predict(predict_dataset)["logits"]
|
||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||
output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
|
||||
with open(output_test_file, "w") as writer:
|
||||
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
||||
with open(output_predict_file, "w") as writer:
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
if is_regression:
|
||||
@@ -529,7 +529,7 @@ def main():
|
||||
else:
|
||||
item = model.config.id2label[item]
|
||||
writer.write(f"{index}\t{item}\n")
|
||||
logger.info(f"Wrote predictions to {output_test_file}!")
|
||||
logger.info(f"Wrote predictions to {output_predict_file}!")
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user