test the full story processing
This commit is contained in:
@@ -87,9 +87,9 @@ class TextDataset(Dataset):
|
|||||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||||
assert os.path.isdir(path_to_stories)
|
assert os.path.isdir(path_to_stories)
|
||||||
|
|
||||||
stories_files = os.listdir(path_to_stories)
|
story_filenames_list = os.listdir(path_to_stories)
|
||||||
for story_file in stories_files:
|
for story_filename in story_filenames_list:
|
||||||
path_to_story = os.path.join(path_to_stories, "story_file")
|
path_to_story = os.path.join(path_to_stories, story_filename)
|
||||||
if not os.path.isfile(path_to_story):
|
if not os.path.isfile(path_to_story):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -97,16 +97,16 @@ class TextDataset(Dataset):
|
|||||||
try:
|
try:
|
||||||
raw_story = source.read()
|
raw_story = source.read()
|
||||||
story, summary = process_story(raw_story)
|
story, summary = process_story(raw_story)
|
||||||
except IndexError:
|
except IndexError: # skip ill-formed stories
|
||||||
continue
|
continue
|
||||||
|
|
||||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||||
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
|
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
|
||||||
example = tokenizer.add_special_token_sequence_pair(
|
|
||||||
story_seq, summary_seq
|
self.examples.append(
|
||||||
|
tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||||
)
|
)
|
||||||
self.examples.append(example)
|
|
||||||
|
|
||||||
logger.info("Saving features into cache file %s", cached_features_file)
|
logger.info("Saving features into cache file %s", cached_features_file)
|
||||||
with open(cached_features_file, "wb") as sink:
|
with open(cached_features_file, "wb") as sink:
|
||||||
@@ -120,8 +120,13 @@ class TextDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
def process_story(raw_story):
|
def process_story(raw_story):
|
||||||
""" Process the text contained in a story file.
|
""" Extract the story and summary from a story file.
|
||||||
Returns the story and the summary
|
|
||||||
|
Attributes:
|
||||||
|
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
IndexError: If the stoy is empty or contains no highlights.
|
||||||
"""
|
"""
|
||||||
file_lines = list(
|
file_lines = list(
|
||||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||||
@@ -158,7 +163,7 @@ def _add_missing_period(line):
|
|||||||
return line
|
return line
|
||||||
if line[-1] in END_TOKENS:
|
if line[-1] in END_TOKENS:
|
||||||
return line
|
return line
|
||||||
return line + " ."
|
return line + "."
|
||||||
|
|
||||||
|
|
||||||
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
||||||
@@ -169,6 +174,13 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
|||||||
block size of 512 this means limiting the source sequence's length to 384
|
block size of 512 this means limiting the source sequence's length to 384
|
||||||
and the target sequence's length to 128.
|
and the target sequence's length to 128.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
src_sequence (list): a list of ids that maps to the tokens of the
|
||||||
|
source sequence.
|
||||||
|
tgt_sequence (list): a list of ids that maps to the tokens of the
|
||||||
|
target sequence.
|
||||||
|
block_size (int): the model's block size.
|
||||||
|
|
||||||
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
||||||
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,21 +14,21 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from run_seq2seq_finetuning import _fit_to_block_size
|
from run_seq2seq_finetuning import _fit_to_block_size, process_story
|
||||||
|
|
||||||
|
|
||||||
class DataLoaderTest(unittest.TestCase):
|
class DataLoaderTest(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.block_size = 10
|
self.block_size = 10
|
||||||
|
|
||||||
def test_source_and_target_too_small(self):
|
def test_truncate_source_and_target_too_small(self):
|
||||||
""" When the sum of the lengths of the source and target sequences is
|
""" When the sum of the lengths of the source and target sequences is
|
||||||
smaller than the block size (minus the number of special tokens), skip the example. """
|
smaller than the block size (minus the number of special tokens), skip the example. """
|
||||||
src_seq = [1, 2, 3, 4]
|
src_seq = [1, 2, 3, 4]
|
||||||
tgt_seq = [5, 6]
|
tgt_seq = [5, 6]
|
||||||
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
|
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
|
||||||
|
|
||||||
def test_source_and_target_fit_exactly(self):
|
def test_truncate_source_and_target_fit_exactly(self):
|
||||||
""" When the sum of the lengths of the source and target sequences is
|
""" When the sum of the lengths of the source and target sequences is
|
||||||
equal to the block size (minus the number of special tokens), return the
|
equal to the block size (minus the number of special tokens), return the
|
||||||
sequences unchanged. """
|
sequences unchanged. """
|
||||||
@@ -38,27 +38,61 @@ class DataLoaderTest(unittest.TestCase):
|
|||||||
self.assertListEqual(src_seq, fitted_src)
|
self.assertListEqual(src_seq, fitted_src)
|
||||||
self.assertListEqual(tgt_seq, fitted_tgt)
|
self.assertListEqual(tgt_seq, fitted_tgt)
|
||||||
|
|
||||||
def test_source_too_big_target_ok(self):
|
def test_truncate_source_too_big_target_ok(self):
|
||||||
src_seq = [1, 2, 3, 4, 5, 6]
|
src_seq = [1, 2, 3, 4, 5, 6]
|
||||||
tgt_seq = [1, 2]
|
tgt_seq = [1, 2]
|
||||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
|
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
|
||||||
self.assertListEqual(fitted_tgt, fitted_tgt)
|
self.assertListEqual(fitted_tgt, fitted_tgt)
|
||||||
|
|
||||||
def test_target_too_big_source_ok(self):
|
def test_truncate_target_too_big_source_ok(self):
|
||||||
src_seq = [1, 2, 3, 4]
|
src_seq = [1, 2, 3, 4]
|
||||||
tgt_seq = [1, 2, 3, 4]
|
tgt_seq = [1, 2, 3, 4]
|
||||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
self.assertListEqual(fitted_src, src_seq)
|
self.assertListEqual(fitted_src, src_seq)
|
||||||
self.assertListEqual(fitted_tgt, [1, 2, 3])
|
self.assertListEqual(fitted_tgt, [1, 2, 3])
|
||||||
|
|
||||||
def test_source_and_target_too_big(self):
|
def test_truncate_source_and_target_too_big(self):
|
||||||
src_seq = [1, 2, 3, 4, 5, 6, 7]
|
src_seq = [1, 2, 3, 4, 5, 6, 7]
|
||||||
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
|
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
|
||||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
|
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
|
||||||
self.assertListEqual(fitted_tgt, [1, 2])
|
self.assertListEqual(fitted_tgt, [1, 2])
|
||||||
|
|
||||||
|
def test_process_story_no_highlights(self):
|
||||||
|
""" Processing a story with no highlights should raise an exception.
|
||||||
|
"""
|
||||||
|
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||||
|
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||||
|
favoured period, as at this."""
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
process_story(raw_story)
|
||||||
|
|
||||||
|
def test_process_empty_story(self):
|
||||||
|
""" An empty story should also raise and exception.
|
||||||
|
"""
|
||||||
|
raw_story = ""
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
process_story(raw_story)
|
||||||
|
|
||||||
|
def test_story_with_missing_period(self):
|
||||||
|
raw_story = (
|
||||||
|
"It was the year of Our Lord one thousand seven hundred and "
|
||||||
|
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||||
|
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||||
|
)
|
||||||
|
story, summary = process_story(raw_story)
|
||||||
|
|
||||||
|
expected_story = (
|
||||||
|
"It was the year of Our Lord one thousand seven hundred and "
|
||||||
|
"seventy-five. Spiritual revelations were conceded to England at that "
|
||||||
|
"favoured period, as at this."
|
||||||
|
)
|
||||||
|
self.assertEqual(expected_story, story)
|
||||||
|
|
||||||
|
expected_summary = "It was the best of times."
|
||||||
|
self.assertEqual(expected_summary, summary)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user