From 6f52fce673288541339b18a4d293faadd2be2c1d Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Tue, 9 Mar 2021 19:13:45 -0800 Subject: [PATCH] Fixes an issue in `text-classification` where MNLI eval/test datasets are not being preprocessed. (#10621) * Fix MNLI tests * Linter fix --- examples/text-classification/run_glue.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 617f67232b..0c20feaf0b 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -374,17 +374,13 @@ def main(): result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] return result + datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache) if training_args.do_train: if "train" not in datasets: raise ValueError("--do_train requires a train dataset") train_dataset = datasets["train"] if data_args.max_train_samples is not None: train_dataset = train_dataset.select(range(data_args.max_train_samples)) - train_dataset = train_dataset.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - ) if training_args.do_eval: if "validation" not in datasets and "validation_matched" not in datasets: @@ -392,11 +388,6 @@ def main(): eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] if data_args.max_val_samples is not None: eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) - eval_dataset = eval_dataset.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - ) if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: if "test" not in datasets and "test_matched" not in datasets: @@ -404,15 +395,11 @@ def main(): test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] if data_args.max_test_samples is not None: test_dataset = test_dataset.select(range(data_args.max_test_samples)) - test_dataset = test_dataset.map( - preprocess_function, - batched=True, - load_from_cache_file=not data_args.overwrite_cache, - ) # Log a few random samples from the training set: - for index in random.sample(range(len(train_dataset)), 3): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Get the metric function if data_args.task_name is not None: @@ -447,7 +434,7 @@ def main(): trainer = Trainer( model=model, args=training_args, - train_dataset=train_dataset, + train_dataset=train_dataset if training_args.do_train else None, eval_dataset=eval_dataset if training_args.do_eval else None, compute_metrics=compute_metrics, tokenizer=tokenizer,