Black 20 release
This commit is contained in:
@@ -11,7 +11,7 @@ from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class CNNDMDataset(Dataset):
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
"""Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
@@ -31,7 +31,7 @@ class CNNDMDataset(Dataset):
|
||||
"""
|
||||
|
||||
def __init__(self, path="", prefix="train"):
|
||||
""" We initialize the class by listing all the documents to summarize.
|
||||
"""We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
@@ -60,7 +60,7 @@ class CNNDMDataset(Dataset):
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
""" Extract the story and summary from a story file.
|
||||
"""Extract the story and summary from a story file.
|
||||
|
||||
Arguments:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
@@ -108,7 +108,7 @@ def _add_missing_period(line):
|
||||
|
||||
|
||||
def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
"""Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter we append padding token to the right of the sequence.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
@@ -119,8 +119,8 @@ def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
""" Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1. """
|
||||
"""Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1."""
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token_id):
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
""" Encode the story and summary lines, and join them
|
||||
"""Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
@@ -141,7 +141,7 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
""" Segment embeddings as described in [1]
|
||||
"""Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
|
||||
Reference in New Issue
Block a user