[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)
* added changes for uniformity
* modified files
* corrected typo
* fixed qa scripts
* fix typos
* fixed predict typo in qa no trainer
* fixed test file
* reverted trainer changes
* reverted trainer changes in custom examples
* updated readme
* added changes in deepspeed test
* added changes for predict and eval
This commit is contained in:
@@ -50,8 +50,8 @@ For example here is how to truncate all three splits to just 50 samples each:
|
|||||||
```
|
```
|
||||||
examples/pytorch/token-classification/run_ner.py \
|
examples/pytorch/token-classification/run_ner.py \
|
||||||
--max_train_samples 50 \
|
--max_train_samples 50 \
|
||||||
--max_val_samples 50 \
|
--max_eval_samples 50 \
|
||||||
--max_test_samples 50 \
|
--max_predict_samples 50 \
|
||||||
[...]
|
[...]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -126,10 +126,10 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -397,8 +397,8 @@ def main():
|
|||||||
if "validation" not in tokenized_datasets:
|
if "validation" not in tokenized_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = lm_datasets["validation"]
|
eval_dataset = lm_datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
@@ -439,8 +439,8 @@ def main():
|
|||||||
|
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
perplexity = math.exp(metrics["eval_loss"])
|
perplexity = math.exp(metrics["eval_loss"])
|
||||||
metrics["perplexity"] = perplexity
|
metrics["perplexity"] = perplexity
|
||||||
|
|
||||||
|
|||||||
@@ -157,10 +157,10 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -419,8 +419,8 @@ def main():
|
|||||||
if "validation" not in tokenized_datasets:
|
if "validation" not in tokenized_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = tokenized_datasets["validation"]
|
eval_dataset = tokenized_datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# This one will take care of randomly masking the tokens.
|
# This one will take care of randomly masking the tokens.
|
||||||
@@ -468,8 +468,8 @@ def main():
|
|||||||
|
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
perplexity = math.exp(metrics["eval_loss"])
|
perplexity = math.exp(metrics["eval_loss"])
|
||||||
metrics["perplexity"] = perplexity
|
metrics["perplexity"] = perplexity
|
||||||
|
|
||||||
|
|||||||
@@ -154,10 +154,10 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -397,8 +397,8 @@ def main():
|
|||||||
if "validation" not in tokenized_datasets:
|
if "validation" not in tokenized_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = tokenized_datasets["validation"]
|
eval_dataset = tokenized_datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
data_collator = DataCollatorForPermutationLanguageModeling(
|
data_collator = DataCollatorForPermutationLanguageModeling(
|
||||||
@@ -444,8 +444,8 @@ def main():
|
|||||||
|
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
perplexity = math.exp(metrics["eval_loss"])
|
perplexity = math.exp(metrics["eval_loss"])
|
||||||
metrics["perplexity"] = perplexity
|
metrics["perplexity"] = perplexity
|
||||||
|
|
||||||
|
|||||||
@@ -127,10 +127,10 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -363,8 +363,8 @@ def main():
|
|||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -422,8 +422,8 @@ def main():
|
|||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|||||||
@@ -133,17 +133,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -468,9 +468,9 @@ def main():
|
|||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_examples = datasets["validation"]
|
eval_examples = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||||
# Validation Feature Creation
|
# Validation Feature Creation
|
||||||
eval_dataset = eval_examples.map(
|
eval_dataset = eval_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
@@ -479,28 +479,28 @@ def main():
|
|||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
if "test" not in datasets:
|
if "test" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_examples = datasets["test"]
|
predict_examples = datasets["test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
test_examples = test_examples.select(range(data_args.max_test_samples))
|
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
|
||||||
# Test Feature Creation
|
# Predict Feature Creation
|
||||||
test_dataset = test_examples.map(
|
predict_dataset = predict_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||||
@@ -581,8 +581,8 @@ def main():
|
|||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
@@ -590,14 +590,16 @@ def main():
|
|||||||
# Prediction
|
# Prediction
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
results = trainer.predict(test_dataset, test_examples)
|
results = trainer.predict(predict_dataset, predict_examples)
|
||||||
metrics = results.metrics
|
metrics = results.metrics
|
||||||
|
|
||||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
max_predict_samples = (
|
||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
)
|
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|||||||
@@ -132,17 +132,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -504,9 +504,9 @@ def main():
|
|||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_examples = datasets["validation"]
|
eval_examples = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
# Selecting Eval Samples from Dataset
|
# Selecting Eval Samples from Dataset
|
||||||
eval_examples = eval_examples.select(range(data_args.max_val_samples))
|
eval_examples = eval_examples.select(range(data_args.max_eval_samples))
|
||||||
# Create Features from Eval Dataset
|
# Create Features from Eval Dataset
|
||||||
eval_dataset = eval_examples.map(
|
eval_dataset = eval_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
@@ -515,28 +515,28 @@ def main():
|
|||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
# Selecting Samples from Dataset again since Feature Creation might increase samples size
|
# Selecting Samples from Dataset again since Feature Creation might increase samples size
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
if "test" not in datasets:
|
if "test" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_examples = datasets["test"]
|
predict_examples = datasets["test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
test_examples = test_examples.select(range(data_args.max_test_samples))
|
predict_examples = predict_examples.select(range(data_args.max_predict_samples))
|
||||||
# Test Feature Creation
|
# Test Feature Creation
|
||||||
test_dataset = test_examples.map(
|
predict_dataset = predict_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||||
@@ -620,8 +620,8 @@ def main():
|
|||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
@@ -629,14 +629,16 @@ def main():
|
|||||||
# Prediction
|
# Prediction
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
results = trainer.predict(test_dataset, test_examples)
|
results = trainer.predict(predict_dataset, predict_examples)
|
||||||
metrics = results.metrics
|
metrics = results.metrics
|
||||||
|
|
||||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
max_predict_samples = (
|
||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
)
|
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|||||||
@@ -183,20 +183,20 @@ def parse_args():
|
|||||||
"value if set.",
|
"value if set.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_val_samples",
|
"--max_eval_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of validation examples to this "
|
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set.",
|
"value if set.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_test_samples",
|
"--max_predict_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of test examples to this",
|
help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -481,9 +481,9 @@ def main():
|
|||||||
if "validation" not in raw_datasets:
|
if "validation" not in raw_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_examples = raw_datasets["validation"]
|
eval_examples = raw_datasets["validation"]
|
||||||
if args.max_val_samples is not None:
|
if args.max_eval_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
eval_examples = eval_examples.select(range(args.max_val_samples))
|
eval_examples = eval_examples.select(range(args.max_eval_samples))
|
||||||
# Validation Feature Creation
|
# Validation Feature Creation
|
||||||
eval_dataset = eval_examples.map(
|
eval_dataset = eval_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
@@ -493,28 +493,28 @@ def main():
|
|||||||
load_from_cache_file=not args.overwrite_cache,
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.max_val_samples is not None:
|
if args.max_eval_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
eval_dataset = eval_dataset.select(range(args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(args.max_eval_samples))
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
if "test" not in raw_datasets:
|
if "test" not in raw_datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_examples = raw_datasets["test"]
|
predict_examples = raw_datasets["test"]
|
||||||
if args.max_test_samples is not None:
|
if args.max_predict_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
test_examples = test_examples.select(range(args.max_test_samples))
|
predict_examples = predict_examples.select(range(args.max_predict_samples))
|
||||||
# Test Feature Creation
|
# Predict Feature Creation
|
||||||
test_dataset = test_examples.map(
|
predict_dataset = predict_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=args.preprocessing_num_workers,
|
num_proc=args.preprocessing_num_workers,
|
||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
load_from_cache_file=not args.overwrite_cache,
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
)
|
)
|
||||||
if args.max_test_samples is not None:
|
if args.max_predict_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
test_dataset = test_dataset.select(range(args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(args.max_predict_samples))
|
||||||
|
|
||||||
# Log a few random samples from the training set:
|
# Log a few random samples from the training set:
|
||||||
for index in random.sample(range(len(train_dataset)), 3):
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
@@ -539,9 +539,9 @@ def main():
|
|||||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||||
test_dataloader = DataLoader(
|
predict_dataloader = DataLoader(
|
||||||
test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Post-processing:
|
# Post-processing:
|
||||||
@@ -737,7 +737,7 @@ def main():
|
|||||||
all_end_top_log_probs = []
|
all_end_top_log_probs = []
|
||||||
all_end_top_index = []
|
all_end_top_index = []
|
||||||
all_cls_logits = []
|
all_cls_logits = []
|
||||||
for step, batch in enumerate(test_dataloader):
|
for step, batch in enumerate(predict_dataloader):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
start_top_log_probs = outputs.start_top_log_probs
|
start_top_log_probs = outputs.start_top_log_probs
|
||||||
@@ -762,10 +762,10 @@ def main():
|
|||||||
max_len = max([x.shape[1] for x in all_end_top_log_probs]) # Get the max_length of the tensor
|
max_len = max([x.shape[1] for x in all_end_top_log_probs]) # Get the max_length of the tensor
|
||||||
|
|
||||||
# concatenate all numpy arrays collected above
|
# concatenate all numpy arrays collected above
|
||||||
start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, test_dataset, max_len)
|
start_top_log_probs_concat = create_and_fill_np_array(all_start_top_log_probs, predict_dataset, max_len)
|
||||||
start_top_index_concat = create_and_fill_np_array(all_start_top_index, test_dataset, max_len)
|
start_top_index_concat = create_and_fill_np_array(all_start_top_index, predict_dataset, max_len)
|
||||||
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, test_dataset, max_len)
|
end_top_log_probs_concat = create_and_fill_np_array(all_end_top_log_probs, predict_dataset, max_len)
|
||||||
end_top_index_concat = create_and_fill_np_array(all_end_top_index, test_dataset, max_len)
|
end_top_index_concat = create_and_fill_np_array(all_end_top_index, predict_dataset, max_len)
|
||||||
all_cls_logits = np.concatenate(all_cls_logits, axis=0)
|
all_cls_logits = np.concatenate(all_cls_logits, axis=0)
|
||||||
|
|
||||||
# delete the list of numpy arrays
|
# delete the list of numpy arrays
|
||||||
@@ -774,7 +774,7 @@ def main():
|
|||||||
del end_top_log_probs
|
del end_top_log_probs
|
||||||
del end_top_index
|
del end_top_index
|
||||||
|
|
||||||
test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
|
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
|
||||||
outputs_numpy = (
|
outputs_numpy = (
|
||||||
start_top_log_probs_concat,
|
start_top_log_probs_concat,
|
||||||
start_top_index_concat,
|
start_top_index_concat,
|
||||||
@@ -783,9 +783,9 @@ def main():
|
|||||||
cls_logits,
|
cls_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
|
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
|
||||||
test_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||||
logger.info(f"Test metrics: {test_metric}")
|
logger.info(f"Predict metrics: {predict_metric}")
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -205,20 +205,20 @@ def parse_args():
|
|||||||
"value if set.",
|
"value if set.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_val_samples",
|
"--max_eval_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of validation examples to this "
|
help="For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set.",
|
"value if set.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_test_samples",
|
"--max_predict_samples",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="For debugging purposes or quicker training, truncate the number of test examples to this",
|
help="For debugging purposes or quicker training, truncate the number of prediction examples to this",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model_type",
|
"--model_type",
|
||||||
@@ -486,9 +486,9 @@ def main():
|
|||||||
if "validation" not in raw_datasets:
|
if "validation" not in raw_datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_examples = raw_datasets["validation"]
|
eval_examples = raw_datasets["validation"]
|
||||||
if args.max_val_samples is not None:
|
if args.max_eval_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
eval_examples = eval_examples.select(range(args.max_val_samples))
|
eval_examples = eval_examples.select(range(args.max_eval_samples))
|
||||||
# Validation Feature Creation
|
# Validation Feature Creation
|
||||||
eval_dataset = eval_examples.map(
|
eval_dataset = eval_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
@@ -498,28 +498,28 @@ def main():
|
|||||||
load_from_cache_file=not args.overwrite_cache,
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.max_val_samples is not None:
|
if args.max_eval_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
eval_dataset = eval_dataset.select(range(args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(args.max_eval_samples))
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
if "test" not in raw_datasets:
|
if "test" not in raw_datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_examples = raw_datasets["test"]
|
predict_examples = raw_datasets["test"]
|
||||||
if args.max_test_samples is not None:
|
if args.max_predict_samples is not None:
|
||||||
# We will select sample from whole data
|
# We will select sample from whole data
|
||||||
test_examples = test_examples.select(range(args.max_test_samples))
|
predict_examples = predict_examples.select(range(args.max_predict_samples))
|
||||||
# Test Feature Creation
|
# Predict Feature Creation
|
||||||
test_dataset = test_examples.map(
|
predict_dataset = predict_examples.map(
|
||||||
prepare_validation_features,
|
prepare_validation_features,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=args.preprocessing_num_workers,
|
num_proc=args.preprocessing_num_workers,
|
||||||
remove_columns=column_names,
|
remove_columns=column_names,
|
||||||
load_from_cache_file=not args.overwrite_cache,
|
load_from_cache_file=not args.overwrite_cache,
|
||||||
)
|
)
|
||||||
if args.max_test_samples is not None:
|
if args.max_predict_samples is not None:
|
||||||
# During Feature creation dataset samples might increase, we will select required samples again
|
# During Feature creation dataset samples might increase, we will select required samples again
|
||||||
test_dataset = test_dataset.select(range(args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(args.max_predict_samples))
|
||||||
|
|
||||||
# Log a few random samples from the training set:
|
# Log a few random samples from the training set:
|
||||||
for index in random.sample(range(len(train_dataset)), 3):
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
@@ -544,9 +544,9 @@ def main():
|
|||||||
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
|
||||||
|
|
||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
predict_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"])
|
||||||
test_dataloader = DataLoader(
|
predict_dataloader = DataLoader(
|
||||||
test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
predict_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Post-processing:
|
# Post-processing:
|
||||||
@@ -714,7 +714,7 @@ def main():
|
|||||||
if args.do_predict:
|
if args.do_predict:
|
||||||
all_start_logits = []
|
all_start_logits = []
|
||||||
all_end_logits = []
|
all_end_logits = []
|
||||||
for step, batch in enumerate(test_dataloader):
|
for step, batch in enumerate(predict_dataloader):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
start_logits = outputs.start_logits
|
start_logits = outputs.start_logits
|
||||||
@@ -729,19 +729,19 @@ def main():
|
|||||||
|
|
||||||
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
|
max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor
|
||||||
# concatenate the numpy array
|
# concatenate the numpy array
|
||||||
start_logits_concat = create_and_fill_np_array(all_start_logits, test_dataset, max_len)
|
start_logits_concat = create_and_fill_np_array(all_start_logits, predict_dataset, max_len)
|
||||||
end_logits_concat = create_and_fill_np_array(all_end_logits, test_dataset, max_len)
|
end_logits_concat = create_and_fill_np_array(all_end_logits, predict_dataset, max_len)
|
||||||
|
|
||||||
# delete the list of numpy arrays
|
# delete the list of numpy arrays
|
||||||
del all_start_logits
|
del all_start_logits
|
||||||
del all_end_logits
|
del all_end_logits
|
||||||
|
|
||||||
# Now we need to add extra columns which we removed for post processing
|
# Now we need to add extra columns which we removed for post processing
|
||||||
test_dataset.set_format(type=None, columns=list(test_dataset.features.keys()))
|
predict_dataset.set_format(type=None, columns=list(predict_dataset.features.keys()))
|
||||||
outputs_numpy = (start_logits_concat, end_logits_concat)
|
outputs_numpy = (start_logits_concat, end_logits_concat)
|
||||||
prediction = post_processing_function(test_examples, test_dataset, outputs_numpy)
|
prediction = post_processing_function(predict_examples, predict_dataset, outputs_numpy)
|
||||||
eval_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
predict_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids)
|
||||||
logger.info(f"Test metrics: {eval_metric}")
|
logger.info(f"Predict metrics: {predict_metric}")
|
||||||
|
|
||||||
if args.output_dir is not None:
|
if args.output_dir is not None:
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
@@ -66,16 +66,16 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def predict(self, test_dataset, test_examples, ignore_keys=None):
|
def predict(self, predict_dataset, predict_examples, ignore_keys=None):
|
||||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
predict_dataloader = self.get_test_dataloader(predict_dataset)
|
||||||
|
|
||||||
# Temporarily disable metric computation, we will do it in the loop here.
|
# Temporarily disable metric computation, we will do it in the loop here.
|
||||||
compute_metrics = self.compute_metrics
|
compute_metrics = self.compute_metrics
|
||||||
self.compute_metrics = None
|
self.compute_metrics = None
|
||||||
try:
|
try:
|
||||||
output = self.prediction_loop(
|
output = self.prediction_loop(
|
||||||
test_dataloader,
|
predict_dataloader,
|
||||||
description="Evaluation",
|
description="Prediction",
|
||||||
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
# No point gathering the predictions if there are no metrics, otherwise we defer to
|
||||||
# self.args.prediction_loss_only
|
# self.args.prediction_loss_only
|
||||||
prediction_loss_only=True if compute_metrics is None else None,
|
prediction_loss_only=True if compute_metrics is None else None,
|
||||||
@@ -87,7 +87,7 @@ class QuestionAnsweringTrainer(Trainer):
|
|||||||
if self.post_process_function is None or self.compute_metrics is None:
|
if self.post_process_function is None or self.compute_metrics is None:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
eval_preds = self.post_process_function(test_examples, test_dataset, output.predictions, "test")
|
predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
|
||||||
metrics = self.compute_metrics(eval_preds)
|
metrics = self.compute_metrics(predictions)
|
||||||
|
|
||||||
return PredictionOutput(predictions=eval_preds.predictions, label_ids=eval_preds.label_ids, metrics=metrics)
|
return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
|
||||||
|
|||||||
@@ -178,17 +178,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -438,8 +438,8 @@ def main():
|
|||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -452,10 +452,10 @@ def main():
|
|||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "test" not in datasets:
|
if "test" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_dataset = datasets["test"]
|
predict_dataset = datasets["test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
test_dataset = test_dataset.map(
|
predict_dataset = predict_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
@@ -547,37 +547,39 @@ def main():
|
|||||||
metrics = trainer.evaluate(
|
metrics = trainer.evaluate(
|
||||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||||
)
|
)
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Test ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
test_results = trainer.predict(
|
predict_results = trainer.predict(
|
||||||
test_dataset,
|
predict_dataset,
|
||||||
metric_key_prefix="test",
|
metric_key_prefix="predict",
|
||||||
max_length=data_args.val_max_target_length,
|
max_length=data_args.val_max_target_length,
|
||||||
num_beams=data_args.num_beams,
|
num_beams=data_args.num_beams,
|
||||||
)
|
)
|
||||||
metrics = test_results.metrics
|
metrics = predict_results.metrics
|
||||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
max_predict_samples = (
|
||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
)
|
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
test_preds = tokenizer.batch_decode(
|
predictions = tokenizer.batch_decode(
|
||||||
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
)
|
)
|
||||||
test_preds = [pred.strip() for pred in test_preds]
|
predictions = [pred.strip() for pred in predictions]
|
||||||
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||||
with open(output_test_preds_file, "w") as writer:
|
with open(output_prediction_file, "w") as writer:
|
||||||
writer.write("\n".join(test_preds))
|
writer.write("\n".join(predictions))
|
||||||
|
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|||||||
@@ -100,17 +100,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -390,15 +390,15 @@ def main():
|
|||||||
if "validation" not in datasets and "validation_matched" not in datasets:
|
if "validation" not in datasets and "validation_matched" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||||
if "test" not in datasets and "test_matched" not in datasets:
|
if "test" not in datasets and "test_matched" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
|
|
||||||
# Log a few random samples from the training set:
|
# Log a few random samples from the training set:
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
@@ -483,32 +483,34 @@ def main():
|
|||||||
for eval_dataset, task in zip(eval_datasets, tasks):
|
for eval_dataset, task in zip(eval_datasets, tasks):
|
||||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = (
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
|
)
|
||||||
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Test ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||||
tasks = [data_args.task_name]
|
tasks = [data_args.task_name]
|
||||||
test_datasets = [test_dataset]
|
predict_datasets = [predict_dataset]
|
||||||
if data_args.task_name == "mnli":
|
if data_args.task_name == "mnli":
|
||||||
tasks.append("mnli-mm")
|
tasks.append("mnli-mm")
|
||||||
test_datasets.append(datasets["test_mismatched"])
|
predict_datasets.append(datasets["test_mismatched"])
|
||||||
|
|
||||||
for test_dataset, task in zip(test_datasets, tasks):
|
for predict_dataset, task in zip(predict_datasets, tasks):
|
||||||
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
# Removing the `label` columns because it contains -1 and Trainer won't like that.
|
||||||
test_dataset.remove_columns_("label")
|
predict_dataset.remove_columns_("label")
|
||||||
predictions = trainer.predict(test_dataset=test_dataset).predictions
|
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
|
||||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||||
|
|
||||||
output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
|
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
with open(output_test_file, "w") as writer:
|
with open(output_predict_file, "w") as writer:
|
||||||
logger.info(f"***** Test results {task} *****")
|
logger.info(f"***** Predict results {task} *****")
|
||||||
writer.write("index\tprediction\n")
|
writer.write("index\tprediction\n")
|
||||||
for index, item in enumerate(predictions):
|
for index, item in enumerate(predictions):
|
||||||
if is_regression:
|
if is_regression:
|
||||||
|
|||||||
@@ -84,17 +84,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -221,8 +221,8 @@ def main():
|
|||||||
label_list = eval_dataset.features["label"].names
|
label_list = eval_dataset.features["label"].names
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
|
predict_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
|
||||||
label_list = test_dataset.features["label"].names
|
label_list = predict_dataset.features["label"].names
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
@@ -286,8 +286,8 @@ def main():
|
|||||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -295,9 +295,9 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
test_dataset = test_dataset.map(
|
predict_dataset = predict_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
@@ -360,8 +360,8 @@ def main():
|
|||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
metrics = trainer.evaluate(eval_dataset=eval_dataset)
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
@@ -369,18 +369,20 @@ def main():
|
|||||||
# Prediction
|
# Prediction
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
predictions, labels, metrics = trainer.predict(test_dataset)
|
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
|
||||||
|
|
||||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
max_predict_samples = (
|
||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
)
|
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
predictions = np.argmax(predictions, axis=1)
|
predictions = np.argmax(predictions, axis=1)
|
||||||
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
output_predict_file = os.path.join(training_args.output_dir, "predictions.txt")
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
with open(output_test_file, "w") as writer:
|
with open(output_predict_file, "w") as writer:
|
||||||
writer.write("index\tprediction\n")
|
writer.write("index\tprediction\n")
|
||||||
for index, item in enumerate(predictions):
|
for index, item in enumerate(predictions):
|
||||||
item = label_list[item]
|
item = label_list[item]
|
||||||
|
|||||||
@@ -128,17 +128,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -363,8 +363,8 @@ def main():
|
|||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
tokenize_and_align_labels,
|
tokenize_and_align_labels,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -375,10 +375,10 @@ def main():
|
|||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
if "test" not in datasets:
|
if "test" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_dataset = datasets["test"]
|
predict_dataset = datasets["test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
test_dataset = test_dataset.map(
|
predict_dataset = predict_dataset.map(
|
||||||
tokenize_and_align_labels,
|
tokenize_and_align_labels,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
@@ -462,8 +462,8 @@ def main():
|
|||||||
|
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
@@ -472,7 +472,7 @@ def main():
|
|||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
predictions, labels, metrics = trainer.predict(test_dataset)
|
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
|
||||||
predictions = np.argmax(predictions, axis=2)
|
predictions = np.argmax(predictions, axis=2)
|
||||||
|
|
||||||
# Remove ignored index (special tokens)
|
# Remove ignored index (special tokens)
|
||||||
@@ -481,13 +481,13 @@ def main():
|
|||||||
for prediction, label in zip(predictions, labels)
|
for prediction, label in zip(predictions, labels)
|
||||||
]
|
]
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
# Save predictions
|
# Save predictions
|
||||||
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
output_predictions_file = os.path.join(training_args.output_dir, "predictions.txt")
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
with open(output_test_predictions_file, "w") as writer:
|
with open(output_predictions_file, "w") as writer:
|
||||||
for prediction in true_predictions:
|
for prediction in true_predictions:
|
||||||
writer.write(" ".join(prediction) + "\n")
|
writer.write(" ".join(prediction) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -167,17 +167,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -432,8 +432,8 @@ def main():
|
|||||||
if "validation" not in datasets:
|
if "validation" not in datasets:
|
||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
@@ -446,10 +446,10 @@ def main():
|
|||||||
max_target_length = data_args.val_max_target_length
|
max_target_length = data_args.val_max_target_length
|
||||||
if "test" not in datasets:
|
if "test" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_dataset = datasets["test"]
|
predict_dataset = datasets["test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
test_dataset = test_dataset.map(
|
predict_dataset = predict_dataset.map(
|
||||||
preprocess_function,
|
preprocess_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
@@ -539,37 +539,39 @@ def main():
|
|||||||
metrics = trainer.evaluate(
|
metrics = trainer.evaluate(
|
||||||
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
|
||||||
)
|
)
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Test ***")
|
logger.info("*** Predict ***")
|
||||||
|
|
||||||
test_results = trainer.predict(
|
predict_results = trainer.predict(
|
||||||
test_dataset,
|
predict_dataset,
|
||||||
metric_key_prefix="test",
|
metric_key_prefix="predict",
|
||||||
max_length=data_args.val_max_target_length,
|
max_length=data_args.val_max_target_length,
|
||||||
num_beams=data_args.num_beams,
|
num_beams=data_args.num_beams,
|
||||||
)
|
)
|
||||||
metrics = test_results.metrics
|
metrics = predict_results.metrics
|
||||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
max_predict_samples = (
|
||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
|
)
|
||||||
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
if trainer.is_world_process_zero():
|
if trainer.is_world_process_zero():
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
test_preds = tokenizer.batch_decode(
|
predictions = tokenizer.batch_decode(
|
||||||
test_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||||
)
|
)
|
||||||
test_preds = [pred.strip() for pred in test_preds]
|
predictions = [pred.strip() for pred in predictions]
|
||||||
output_test_preds_file = os.path.join(training_args.output_dir, "test_generations.txt")
|
output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
|
||||||
with open(output_test_preds_file, "w") as writer:
|
with open(output_prediction_file, "w") as writer:
|
||||||
writer.write("\n".join(test_preds))
|
writer.write("\n".join(predictions))
|
||||||
|
|
||||||
if training_args.push_to_hub:
|
if training_args.push_to_hub:
|
||||||
trainer.push_to_hub()
|
trainer.push_to_hub()
|
||||||
|
|||||||
@@ -164,17 +164,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of predict examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -468,13 +468,13 @@ def main():
|
|||||||
|
|
||||||
if "validation" in datasets:
|
if "validation" in datasets:
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = datasets["validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
|
|
||||||
if "test" in datasets:
|
if "test" in datasets:
|
||||||
test_dataset = datasets["test"]
|
predict_dataset = datasets["test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
@@ -513,15 +513,15 @@ def main():
|
|||||||
|
|
||||||
# region Prediction
|
# region Prediction
|
||||||
if "test" in datasets:
|
if "test" in datasets:
|
||||||
logger.info("Doing predictions on test dataset...")
|
logger.info("Doing predictions on Predict dataset...")
|
||||||
|
|
||||||
test_dataset = DataSequence(
|
predict_dataset = DataSequence(
|
||||||
test_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
predict_dataset, non_label_column_names, batch_size=training_args.per_device_eval_batch_size, labels=False
|
||||||
)
|
)
|
||||||
predictions = model.predict(test_dataset)["logits"]
|
predictions = model.predict(predict_dataset)["logits"]
|
||||||
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
|
||||||
output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
|
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
|
||||||
with open(output_test_file, "w") as writer:
|
with open(output_predict_file, "w") as writer:
|
||||||
writer.write("index\tprediction\n")
|
writer.write("index\tprediction\n")
|
||||||
for index, item in enumerate(predictions):
|
for index, item in enumerate(predictions):
|
||||||
if is_regression:
|
if is_regression:
|
||||||
@@ -529,7 +529,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
item = model.config.id2label[item]
|
item = model.config.id2label[item]
|
||||||
writer.write(f"{index}\t{item}\n")
|
writer.write(f"{index}\t{item}\n")
|
||||||
logger.info(f"Wrote predictions to {output_test_file}!")
|
logger.info(f"Wrote predictions to {output_predict_file}!")
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -157,17 +157,17 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_val_samples: Optional[int] = field(
|
max_eval_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
max_test_samples: Optional[int] = field(
|
max_predict_samples: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
|
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -379,8 +379,8 @@ def main():
|
|||||||
raise ValueError("--do_eval requires a validation dataset")
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
eval_dataset = datasets["validation"]
|
eval_dataset = datasets["validation"]
|
||||||
# Selecting samples from dataset
|
# Selecting samples from dataset
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_eval_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
|
||||||
# tokenize validation dataset
|
# tokenize validation dataset
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
@@ -393,12 +393,12 @@ def main():
|
|||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
if "test" not in datasets:
|
if "test" not in datasets:
|
||||||
raise ValueError("--do_predict requires a test dataset")
|
raise ValueError("--do_predict requires a test dataset")
|
||||||
test_dataset = datasets["test"]
|
predict_dataset = datasets["test"]
|
||||||
# Selecting samples from dataset
|
# Selecting samples from dataset
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
|
||||||
# tokenize test dataset
|
# tokenize predict dataset
|
||||||
test_dataset = test_dataset.map(
|
predict_dataset = predict_dataset.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
@@ -455,8 +455,8 @@ def main():
|
|||||||
|
|
||||||
metrics = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||||
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("eval", metrics)
|
trainer.log_metrics("eval", metrics)
|
||||||
trainer.save_metrics("eval", metrics)
|
trainer.save_metrics("eval", metrics)
|
||||||
@@ -464,13 +464,13 @@ def main():
|
|||||||
# Prediction
|
# Prediction
|
||||||
if training_args.do_predict:
|
if training_args.do_predict:
|
||||||
logger.info("*** Predict ***")
|
logger.info("*** Predict ***")
|
||||||
predictions, labels, metrics = trainer.predict(test_dataset)
|
predictions, labels, metrics = trainer.predict(predict_dataset)
|
||||||
|
|
||||||
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
|
max_predict_samples = data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
|
||||||
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
|
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
|
||||||
|
|
||||||
trainer.log_metrics("test", metrics)
|
trainer.log_metrics("predict", metrics)
|
||||||
trainer.save_metrics("test", metrics)
|
trainer.save_metrics("predict", metrics)
|
||||||
|
|
||||||
# write custom code for saving predictions according to task
|
# write custom code for saving predictions according to task
|
||||||
|
|
||||||
|
|||||||
@@ -578,7 +578,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
args.extend(
|
args.extend(
|
||||||
"""
|
"""
|
||||||
--do_eval
|
--do_eval
|
||||||
--max_val_samples 100
|
--max_eval_samples 100
|
||||||
--per_device_eval_batch_size 2
|
--per_device_eval_batch_size 2
|
||||||
""".split()
|
""".split()
|
||||||
)
|
)
|
||||||
@@ -620,7 +620,7 @@ class TestDeepSpeedWithLauncher(TestCasePlus):
|
|||||||
--do_train
|
--do_train
|
||||||
--do_eval
|
--do_eval
|
||||||
--max_train_samples 10
|
--max_train_samples 10
|
||||||
--max_val_samples 10
|
--max_eval_samples 10
|
||||||
--per_device_train_batch_size 5
|
--per_device_train_batch_size 5
|
||||||
--per_device_eval_batch_size 5
|
--per_device_eval_batch_size 5
|
||||||
--num_train_epochs 1
|
--num_train_epochs 1
|
||||||
|
|||||||
@@ -191,7 +191,7 @@ class TestTrainerExt(TestCasePlus):
|
|||||||
--output_dir {output_dir}
|
--output_dir {output_dir}
|
||||||
--overwrite_output_dir
|
--overwrite_output_dir
|
||||||
--max_train_samples 8
|
--max_train_samples 8
|
||||||
--max_val_samples 8
|
--max_eval_samples 8
|
||||||
--max_source_length {max_len}
|
--max_source_length {max_len}
|
||||||
--max_target_length {max_len}
|
--max_target_length {max_len}
|
||||||
--val_max_target_length {max_len}
|
--val_max_target_length {max_len}
|
||||||
|
|||||||
Reference in New Issue
Block a user