From f2ffcaf49f6baee9e18e5f59da28ef5f0866066d Mon Sep 17 00:00:00 2001 From: Tommy Chiang Date: Mon, 10 May 2021 03:42:38 +0800 Subject: [PATCH] [Examples] Check key exists in datasets first (#11503) --- examples/pytorch/multiple-choice/run_swag.py | 2 +- examples/pytorch/summarization/run_summarization.py | 2 +- examples/pytorch/translation/run_translation.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/multiple-choice/run_swag.py b/examples/pytorch/multiple-choice/run_swag.py index 2ee7ad7356..e0d9e0571e 100755 --- a/examples/pytorch/multiple-choice/run_swag.py +++ b/examples/pytorch/multiple-choice/run_swag.py @@ -347,9 +347,9 @@ def main(): return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()} if training_args.do_train: - train_dataset = datasets["train"] if "train" not in datasets: raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] if data_args.max_train_samples is not None: train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.map( diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index c310cbd4f4..d049482ca8 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -422,9 +422,9 @@ def main(): return model_inputs if training_args.do_train: - train_dataset = datasets["train"] if "train" not in datasets: raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] if data_args.max_train_samples is not None: train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.map( diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 56503f98ef..c6d83b30a1 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -416,9 +416,9 @@ def main(): return model_inputs if training_args.do_train: - train_dataset = datasets["train"] if "train" not in datasets: raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] if data_args.max_train_samples is not None: train_dataset = train_dataset.select(range(data_args.max_train_samples)) train_dataset = train_dataset.map(