added max_sample args and metrics changes (#10602)
This commit is contained in:
@@ -144,6 +144,20 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||||
)
|
)
|
||||||
|
max_train_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_val_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
|
||||||
|
"value if set."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
||||||
@@ -317,13 +331,37 @@ def main():
|
|||||||
def tokenize_function(examples):
|
def tokenize_function(examples):
|
||||||
return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
|
return tokenizer(examples[text_column_name], padding="max_length", truncation=True)
|
||||||
|
|
||||||
tokenized_datasets = datasets.map(
|
if training_args.do_train:
|
||||||
tokenize_function,
|
if "train" not in datasets:
|
||||||
batched=True,
|
raise ValueError("--do_train requires a train dataset")
|
||||||
num_proc=data_args.preprocessing_num_workers,
|
train_dataset = datasets["train"]
|
||||||
remove_columns=[text_column_name],
|
if data_args.max_train_samples is not None:
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
# Select Sample from Dataset
|
||||||
)
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
|
# tokenize train dataset in batch
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
tokenize_function,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=[text_column_name],
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
if training_args.do_eval:
|
||||||
|
if "validation" not in datasets:
|
||||||
|
raise ValueError("--do_eval requires a validation dataset")
|
||||||
|
eval_dataset = datasets["validation"]
|
||||||
|
# Selecting samples from dataset
|
||||||
|
if data_args.max_val_samples is not None:
|
||||||
|
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||||
|
# tokenize validation dataset
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
tokenize_function,
|
||||||
|
batched=True,
|
||||||
|
num_proc=data_args.preprocessing_num_workers,
|
||||||
|
remove_columns=[text_column_name],
|
||||||
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
|
)
|
||||||
|
|
||||||
# Data collator
|
# Data collator
|
||||||
data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||||
@@ -332,8 +370,8 @@ def main():
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
|
train_dataset=train_dataset if training_args.do_train else None,
|
||||||
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
)
|
)
|
||||||
@@ -358,33 +396,27 @@ def main():
|
|||||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||||
trainer.save_model() # Saves the tokenizer too for easy upload
|
trainer.save_model() # Saves the tokenizer too for easy upload
|
||||||
|
|
||||||
output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
|
metrics = train_result.metrics
|
||||||
if trainer.is_world_process_zero():
|
max_train_samples = (
|
||||||
with open(output_train_file, "w") as writer:
|
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||||
logger.info("***** Train results *****")
|
)
|
||||||
for key, value in sorted(train_result.metrics.items()):
|
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||||
logger.info(f" {key} = {value}")
|
|
||||||
writer.write(f"{key} = {value}\n")
|
|
||||||
|
|
||||||
# Need to save the state, since Trainer.save_model saves only the tokenizer with the model
|
trainer.log_metrics("train", metrics)
|
||||||
trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
|
trainer.save_metrics("train", metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
|
||||||
# Evaluation
|
# Evaluation
|
||||||
results = {}
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
logger.info("*** Evaluate ***")
|
logger.info("*** Evaluate ***")
|
||||||
|
|
||||||
results = trainer.evaluate()
|
metrics = trainer.evaluate()
|
||||||
|
|
||||||
output_eval_file = os.path.join(training_args.output_dir, "eval_results_{{cookiecutter.example_shortcut}}.txt")
|
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
|
||||||
if trainer.is_world_process_zero():
|
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
|
||||||
with open(output_eval_file, "w") as writer:
|
|
||||||
logger.info("***** Eval results *****")
|
|
||||||
for key, value in sorted(results.items()):
|
|
||||||
logger.info(f" {key} = {value}")
|
|
||||||
writer.write(f"{key} = {value}\n")
|
|
||||||
|
|
||||||
return results
|
trainer.log_metrics("eval", metrics)
|
||||||
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
|
||||||
def _mp_fn(index):
|
def _mp_fn(index):
|
||||||
|
|||||||
Reference in New Issue
Block a user