Fixes an issue in text-classification where MNLI eval/test datasets are not being preprocessed. (#10621)
* Fix MNLI tests * Linter fix
This commit is contained in:
@@ -374,17 +374,13 @@ def main():
|
|||||||
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
|
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in datasets:
|
if "train" not in datasets:
|
||||||
raise ValueError("--do_train requires a train dataset")
|
raise ValueError("--do_train requires a train dataset")
|
||||||
train_dataset = datasets["train"]
|
train_dataset = datasets["train"]
|
||||||
if data_args.max_train_samples is not None:
|
if data_args.max_train_samples is not None:
|
||||||
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
train_dataset = train_dataset.select(range(data_args.max_train_samples))
|
||||||
train_dataset = train_dataset.map(
|
|
||||||
preprocess_function,
|
|
||||||
batched=True,
|
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
if "validation" not in datasets and "validation_matched" not in datasets:
|
if "validation" not in datasets and "validation_matched" not in datasets:
|
||||||
@@ -392,11 +388,6 @@ def main():
|
|||||||
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
|
||||||
if data_args.max_val_samples is not None:
|
if data_args.max_val_samples is not None:
|
||||||
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
|
||||||
eval_dataset = eval_dataset.map(
|
|
||||||
preprocess_function,
|
|
||||||
batched=True,
|
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
|
||||||
if "test" not in datasets and "test_matched" not in datasets:
|
if "test" not in datasets and "test_matched" not in datasets:
|
||||||
@@ -404,13 +395,9 @@ def main():
|
|||||||
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
|
||||||
if data_args.max_test_samples is not None:
|
if data_args.max_test_samples is not None:
|
||||||
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
test_dataset = test_dataset.select(range(data_args.max_test_samples))
|
||||||
test_dataset = test_dataset.map(
|
|
||||||
preprocess_function,
|
|
||||||
batched=True,
|
|
||||||
load_from_cache_file=not data_args.overwrite_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log a few random samples from the training set:
|
# Log a few random samples from the training set:
|
||||||
|
if training_args.do_train:
|
||||||
for index in random.sample(range(len(train_dataset)), 3):
|
for index in random.sample(range(len(train_dataset)), 3):
|
||||||
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
|
||||||
|
|
||||||
@@ -447,7 +434,7 @@ def main():
|
|||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
train_dataset=train_dataset,
|
train_dataset=train_dataset if training_args.do_train else None,
|
||||||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||||
compute_metrics=compute_metrics,
|
compute_metrics=compute_metrics,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
|||||||
Reference in New Issue
Block a user