From 3a9a9f78614050896356a9a30e9529c502b56d96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 5 Dec 2019 19:09:47 +0100 Subject: [PATCH] default output dir to documents dir --- examples/summarization/run_summarization.py | 11 ++++++----- examples/summarization/utils_summarization.py | 2 ++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/summarization/run_summarization.py b/examples/summarization/run_summarization.py index e3b974acd9..bbc79227ca 100644 --- a/examples/summarization/run_summarization.py +++ b/examples/summarization/run_summarization.py @@ -31,9 +31,7 @@ Batch = namedtuple( def evaluate(args): tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) - model = bertabs = BertAbs.from_pretrained( - "bertabs-finetuned-{}".format(args.finetuned_model) - ) + model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm") bertabs.to(args.device) bertabs.eval() @@ -195,8 +193,8 @@ def main(): "--summaries_output_dir", default=None, type=str, - required=True, - help="The folder in wich the summaries should be written.", + required=False, + help="The folder in wich the summaries should be written. Defaults to the folder where the documents are", ) # EVALUATION options parser.add_argument( @@ -242,6 +240,9 @@ def main(): args = parser.parse_args() args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda") + if not args.summaries_output_dir: + args.summaries_output_dir = args.documents_dir + if not documents_dir_is_valid(args.documents_dir): raise FileNotFoundError( "We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path." diff --git a/examples/summarization/utils_summarization.py b/examples/summarization/utils_summarization.py index e7401b1754..1d8c436ac9 100644 --- a/examples/summarization/utils_summarization.py +++ b/examples/summarization/utils_summarization.py @@ -39,6 +39,8 @@ class SummarizationDataset(Dataset): self.documents = [] story_filenames_list = os.listdir(path) for story_filename in story_filenames_list: + if "summary" in story_filename: + continue path_to_story = os.path.join(path, story_filename) if not os.path.isfile(path_to_story): continue