Fixes an issue in text-classification where MNLI eval/test datasets are not being preprocessed. (#10621)
* Fix MNLI tests * Linter fix
This commit is contained in:
@@ -374,17 +374,13 @@ def main():
|
||||
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
|
||||
return result
|
||||
|
||||
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||
if training_args.do_train:
|
||||
if "train" not in datasets:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
train_dataset = datasets["train"]
|
||||
if data_args.max_train_samples is not None:
|
||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
if training_args.do_eval:
|
||||
if "validation" not in datasets and "validation_matched" not in datasets:
|
||||
@@ -392,11 +388,6 @@ def main():
|
||||
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||
if data_args.max_val_samples is not None:
|
||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||
if "test" not in datasets and "test_matched" not in datasets:
|
||||
@@ -404,15 +395,11 @@ def main():
|
||||
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
||||
if data_args.max_test_samples is not None:
|
||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||
test_dataset = test_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
# Log a few random samples from the training set:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
if training_args.do_train:
|
||||
for index in random.sample(range(len(train_dataset)), 3):
|
||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||
|
||||
# Get the metric function
|
||||
if data_args.task_name is not None:
|
||||
@@ -447,7 +434,7 @@ def main():
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
train_dataset=train_dataset if training_args.do_train else None,
|
||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
compute_metrics=compute_metrics,
|
||||
tokenizer=tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user