Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -99,7 +99,7 @@ def evaluate(args):
def save_summaries(summaries, path, original_document_name):
""" Write the summaries in files that are prefixed by the original
"""Write the summaries in files that are prefixed by the original
files' name with the `_summary` appended.
Attributes:
@@ -125,7 +125,7 @@ def save_summaries(summaries, path, original_document_name):
def format_summary(translation):
""" Transforms the output of the `from_batch` function
"""Transforms the output of the `from_batch` function
into nicely formatted summaries.
"""
raw_summary, _, _ = translation
@@ -190,7 +190,12 @@ def build_data_iterator(args, tokenizer):
def collate_fn(data):
return collate(data, tokenizer, block_size=512, device=args.device)
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
iterator = DataLoader(
dataset,
sampler=sampler,
batch_size=args.batch_size,
collate_fn=collate_fn,
)
return iterator
@@ -201,7 +206,7 @@ def load_and_cache_examples(args, tokenizer):
def collate(data, tokenizer, block_size, device):
""" Collate formats the data passed to the data loader.
"""Collate formats the data passed to the data loader.
In particular we tokenize the data batch after batch to avoid keeping them
all in memory. We output the data as a namedtuple to fit the original BertAbs's
@@ -231,7 +236,7 @@ def collate(data, tokenizer, block_size, device):
def decode_summary(summary_tokens, tokenizer):
""" Decode the summary and return it in a format
"""Decode the summary and return it in a format
suitable for evaluation.
"""
summary_tokens = summary_tokens.to("cpu").numpy()
@@ -242,8 +247,7 @@ def decode_summary(summary_tokens, tokenizer):
def main():
""" The main function defines the interface with the users.
"""
"""The main function defines the interface with the users."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--documents_dir",
@@ -268,23 +272,41 @@ def main():
)
# EVALUATION options
parser.add_argument(
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
"--no_cuda",
default=False,
type=bool,
help="Whether to force the execution on CPU.",
)
parser.add_argument(
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
"--batch_size",
default=4,
type=int,
help="Batch size per GPU/CPU for training.",
)
# BEAM SEARCH arguments
parser.add_argument(
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
"--min_length",
default=50,
type=int,
help="Minimum number of tokens for the summaries.",
)
parser.add_argument(
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
"--max_length",
default=200,
type=int,
help="Maixmum number of tokens for the summaries.",
)
parser.add_argument(
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
"--beam_size",
default=5,
type=int,
help="The number of beams to start with for each example.",
)
parser.add_argument(
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
"--alpha",
default=0.95,
type=float,
help="The value of alpha for the length penalty in the beam search.",
)
parser.add_argument(
"--block_trigram",