From fd338abdeba25cb40b27650ba2203ac6789d2776 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 6 Apr 2021 19:54:13 -0400 Subject: [PATCH] Style --- .../run_qa_beam_search_no_trainer.py | 17 +++++------------ .../question-answering/run_qa_no_trainer.py | 10 +++------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/examples/question-answering/run_qa_beam_search_no_trainer.py b/examples/question-answering/run_qa_beam_search_no_trainer.py index bff8cbcd72..15a6269eb1 100644 --- a/examples/question-answering/run_qa_beam_search_no_trainer.py +++ b/examples/question-answering/run_qa_beam_search_no_trainer.py @@ -76,9 +76,7 @@ def parse_args(): parser.add_argument( "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data." ) - parser.add_argument( - "--do_predict", action="store_true", help="Eval the question answering model" - ) + parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model") parser.add_argument( "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." ) @@ -284,7 +282,7 @@ def main(): # Preprocessing the datasets. # Preprocessing is slighlty different for training and evaluation. column_names = raw_datasets["train"].column_names - + question_column_name = "question" if "question" in column_names else column_names[0] context_column_name = "context" if "context" in column_names else column_names[1] answer_column_name = "answers" if "answers" in column_names else column_names[2] @@ -396,7 +394,6 @@ def main(): return tokenized_examples - if "train" not in raw_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = raw_datasets["train"] @@ -481,7 +478,6 @@ def main(): return tokenized_examples - if "validation" not in raw_datasets: raise ValueError("--do_eval requires a validation dataset") eval_examples = raw_datasets["validation"] @@ -539,11 +535,8 @@ def main(): train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size ) - eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"]) - eval_dataloader = DataLoader( - eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size - ) + eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) if args.do_predict: test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"]) @@ -605,8 +598,8 @@ def main(): if step + batch_size < len(dataset): logits_concat[step : step + batch_size, :cols] = output_logit else: - logits_concat[step:, :cols] = output_logit[:len(dataset) - step] - + logits_concat[step:, :cols] = output_logit[: len(dataset) - step] + step += batch_size return logits_concat diff --git a/examples/question-answering/run_qa_no_trainer.py b/examples/question-answering/run_qa_no_trainer.py index 8ea336dda0..e8e4e3a33a 100755 --- a/examples/question-answering/run_qa_no_trainer.py +++ b/examples/question-answering/run_qa_no_trainer.py @@ -81,9 +81,7 @@ def parse_args(): parser.add_argument( "--preprocessing_num_workers", type=int, default=4, help="A csv or a json file containing the training data." ) - parser.add_argument( - "--do_predict", action="store_true", help="Eval the question answering model" - ) + parser.add_argument("--do_predict", action="store_true", help="Eval the question answering model") parser.add_argument( "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." ) @@ -543,9 +541,7 @@ def main(): ) eval_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"]) - eval_dataloader = DataLoader( - eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size - ) + eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) if args.do_predict: test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"]) @@ -607,7 +603,7 @@ def main(): if step + batch_size < len(dataset): logits_concat[step : step + batch_size, :cols] = output_logit else: - logits_concat[step:, :cols] = output_logit[:len(dataset) - step] + logits_concat[step:, :cols] = output_logit[: len(dataset) - step] step += batch_size