Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -11,7 +11,7 @@ from torch.utils.data import Dataset
class CNNDMDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
"""Abstracts the dataset used to train seq2seq models.
The class will process the documents that are located in the specified
folder. The preprocessing will work on any document that is reasonably
@@ -31,7 +31,7 @@ class CNNDMDataset(Dataset):
"""
def __init__(self, path="", prefix="train"):
""" We initialize the class by listing all the documents to summarize.
"""We initialize the class by listing all the documents to summarize.
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
"""
assert os.path.isdir(path)
@@ -60,7 +60,7 @@ class CNNDMDataset(Dataset):
def process_story(raw_story):
""" Extract the story and summary from a story file.
"""Extract the story and summary from a story file.
Arguments:
raw_story (str): content of the story file as a UTF-8 encoded string.
@@ -108,7 +108,7 @@ def _add_missing_period(line):
def truncate_or_pad(sequence, block_size, pad_token_id):
""" Adapt the source and target sequences' lengths to the block size.
"""Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size, we append padding tokens to the right of the sequence.
"""
if len(sequence) > block_size:
@@ -119,8 +119,8 @@ def truncate_or_pad(sequence, block_size, pad_token_id):
def build_mask(sequence, pad_token_id):
""" Builds the mask. The attention mechanism will only attend to positions
with value 1. """
"""Builds the mask. The attention mechanism will only attend to positions
with value 1."""
mask = torch.ones_like(sequence)
idx_pad_tokens = sequence == pad_token_id
mask[idx_pad_tokens] = 0
@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token_id):
def encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
"""Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
@@ -141,7 +141,7 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
"""Segment embeddings as described in [1]
The values {0,1} were found in the repository [2].