test the full story processing
This commit is contained in:
@@ -87,9 +87,9 @@ class TextDataset(Dataset):
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
assert os.path.isdir(path_to_stories)
|
||||
|
||||
stories_files = os.listdir(path_to_stories)
|
||||
for story_file in stories_files:
|
||||
path_to_story = os.path.join(path_to_stories, "story_file")
|
||||
story_filenames_list = os.listdir(path_to_stories)
|
||||
for story_filename in story_filenames_list:
|
||||
path_to_story = os.path.join(path_to_stories, story_filename)
|
||||
if not os.path.isfile(path_to_story):
|
||||
continue
|
||||
|
||||
@@ -97,16 +97,16 @@ class TextDataset(Dataset):
|
||||
try:
|
||||
raw_story = source.read()
|
||||
story, summary = process_story(raw_story)
|
||||
except IndexError:
|
||||
except IndexError: # skip ill-formed stories
|
||||
continue
|
||||
|
||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
|
||||
example = tokenizer.add_special_token_sequence_pair(
|
||||
story_seq, summary_seq
|
||||
|
||||
self.examples.append(
|
||||
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||
)
|
||||
self.examples.append(example)
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
with open(cached_features_file, "wb") as sink:
|
||||
@@ -120,8 +120,13 @@ class TextDataset(Dataset):
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
""" Process the text contained in a story file.
|
||||
Returns the story and the summary
|
||||
""" Extract the story and summary from a story file.
|
||||
|
||||
Attributes:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
|
||||
Raises:
|
||||
IndexError: If the stoy is empty or contains no highlights.
|
||||
"""
|
||||
file_lines = list(
|
||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||
@@ -158,7 +163,7 @@ def _add_missing_period(line):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + " ."
|
||||
return line + "."
|
||||
|
||||
|
||||
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
||||
@@ -169,6 +174,13 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
||||
block size of 512 this means limiting the source sequence's length to 384
|
||||
and the target sequence's length to 128.
|
||||
|
||||
Attributes:
|
||||
src_sequence (list): a list of ids that maps to the tokens of the
|
||||
source sequence.
|
||||
tgt_sequence (list): a list of ids that maps to the tokens of the
|
||||
target sequence.
|
||||
block_size (int): the model's block size.
|
||||
|
||||
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
||||
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user