default output dir to documents dir
This commit is contained in:
committed by
Julien Chaumond
parent
693606a75c
commit
3a9a9f7861
@@ -31,9 +31,7 @@ Batch = namedtuple(
|
|||||||
|
|
||||||
def evaluate(args):
|
def evaluate(args):
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
|
||||||
model = bertabs = BertAbs.from_pretrained(
|
model = bertabs = BertAbs.from_pretrained("bertabs-finetuned-cnndm")
|
||||||
"bertabs-finetuned-{}".format(args.finetuned_model)
|
|
||||||
)
|
|
||||||
bertabs.to(args.device)
|
bertabs.to(args.device)
|
||||||
bertabs.eval()
|
bertabs.eval()
|
||||||
|
|
||||||
@@ -195,8 +193,8 @@ def main():
|
|||||||
"--summaries_output_dir",
|
"--summaries_output_dir",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=False,
|
||||||
help="The folder in wich the summaries should be written.",
|
help="The folder in wich the summaries should be written. Defaults to the folder where the documents are",
|
||||||
)
|
)
|
||||||
# EVALUATION options
|
# EVALUATION options
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -242,6 +240,9 @@ def main():
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
|
args.device = torch.device("cpu") if args.visible_gpus == -1 else torch.device("cuda")
|
||||||
|
|
||||||
|
if not args.summaries_output_dir:
|
||||||
|
args.summaries_output_dir = args.documents_dir
|
||||||
|
|
||||||
if not documents_dir_is_valid(args.documents_dir):
|
if not documents_dir_is_valid(args.documents_dir):
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
"We could not find the directory you specified for the documents to summarize, or it was empty. Please specify a valid path."
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ class SummarizationDataset(Dataset):
|
|||||||
self.documents = []
|
self.documents = []
|
||||||
story_filenames_list = os.listdir(path)
|
story_filenames_list = os.listdir(path)
|
||||||
for story_filename in story_filenames_list:
|
for story_filename in story_filenames_list:
|
||||||
|
if "summary" in story_filename:
|
||||||
|
continue
|
||||||
path_to_story = os.path.join(path, story_filename)
|
path_to_story = os.path.join(path, story_filename)
|
||||||
if not os.path.isfile(path_to_story):
|
if not os.path.isfile(path_to_story):
|
||||||
continue
|
continue
|
||||||
|
|||||||
Reference in New Issue
Block a user