From 7296f1010b6faaf3b1fb409bc5a9ebadcea51973 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 12 Dec 2019 13:01:04 -0500 Subject: [PATCH] Cleanup squad and add allow train_file and predict_file usage --- examples/run_squad.py | 22 ++++++++++++++-------- transformers/data/processors/squad.py | 6 ++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 79c8537a4b..117b86e32c 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -337,7 +337,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal else: logger.info("Creating features from dataset file at %s", input_dir) - if not args.data_dir: + if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)): try: import tensorflow_datasets as tfds except ImportError: @@ -350,7 +350,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate) else: processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() - examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) + + if evaluate: + examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file) + else: + examples = processor.get_train_examples(args.data_dir, filename=args.train_file) features, dataset = squad_convert_examples_to_features( examples=examples, @@ -387,7 +391,14 @@ def main(): ## Other parameters parser.add_argument("--data_dir", default=None, type=str, - help="The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets.") + help="The input data dir. Should contain the .json files for the task." + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.") + parser.add_argument("--train_file", default=None, type=str, + help="The input training file. If a data dir is specified, will look for the file there" + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.") + parser.add_argument("--predict_file", default=None, type=str, + help="The input evaluation file. If a data dir is specified, will look for the file there" + + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.") parser.add_argument("--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name") parser.add_argument("--tokenizer_name", default="", type=str, @@ -472,11 +483,6 @@ def main(): parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") args = parser.parse_args() - args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format( - list(filter(None, args.model_name_or_path.split('/'))).pop(), - str(args.max_seq_length)) - ) - if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) diff --git a/transformers/data/processors/squad.py b/transformers/data/processors/squad.py index 3d7f832540..9bc4375684 100644 --- a/transformers/data/processors/squad.py +++ b/transformers/data/processors/squad.py @@ -373,6 +373,9 @@ class SquadProcessor(DataProcessor): which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. """ + if data_dir is None: + data_dir = "" + if self.train_file is None: raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") @@ -389,6 +392,9 @@ class SquadProcessor(DataProcessor): filename: None by default, specify this if the evaluation file has a different name than the original one which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. """ + if data_dir is None: + data_dir = "" + if self.dev_file is None: raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")