From 1aec940587255083b2451fc18aa604de29c1188c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 15 Oct 2019 15:18:07 +0200 Subject: [PATCH] test the full story processing --- examples/run_seq2seq_finetuning.py | 32 +++++++++++------ examples/run_seq2seq_finetuning_test.py | 46 +++++++++++++++++++++---- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index e926523a17..f05a5847ed 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -87,9 +87,9 @@ class TextDataset(Dataset): path_to_stories = os.path.join(data_dir, dataset, "stories") assert os.path.isdir(path_to_stories) - stories_files = os.listdir(path_to_stories) - for story_file in stories_files: - path_to_story = os.path.join(path_to_stories, "story_file") + story_filenames_list = os.listdir(path_to_stories) + for story_filename in story_filenames_list: + path_to_story = os.path.join(path_to_stories, story_filename) if not os.path.isfile(path_to_story): continue @@ -97,16 +97,16 @@ class TextDataset(Dataset): try: raw_story = source.read() story, summary = process_story(raw_story) - except IndexError: + except IndexError: # skip ill-formed stories continue story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) story_seq, summary_seq = _fit_to_block_size(story, summary, block_size) - example = tokenizer.add_special_token_sequence_pair( - story_seq, summary_seq + + self.examples.append( + tokenizer.add_special_token_sequence_pair(story_seq, summary_seq) ) - self.examples.append(example) logger.info("Saving features into cache file %s", cached_features_file) with open(cached_features_file, "wb") as sink: @@ -120,8 +120,13 @@ class TextDataset(Dataset): def process_story(raw_story): - """ Process the text contained in a story file. 
- Returns the story and the summary + """ Extract the story and summary from a story file. + + Attributes: + raw_story (str): content of the story file as a UTF-8 encoded string. + + Raises: + IndexError: If the story is empty or contains no highlights. """ file_lines = list( filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) @@ -158,7 +163,7 @@ def _add_missing_period(line): return line if line[-1] in END_TOKENS: return line - return line + " ." + return line + "." def _fit_to_block_size(src_sequence, tgt_sequence, block_size): @@ -169,6 +174,13 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size): block size of 512 this means limiting the source sequence's length to 384 and the target sequence's length to 128. + Attributes: + src_sequence (list): a list of ids that maps to the tokens of the + source sequence. + tgt_sequence (list): a list of ids that maps to the tokens of the + target sequence. + block_size (int): the model's block size. + [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019). """ diff --git a/examples/run_seq2seq_finetuning_test.py b/examples/run_seq2seq_finetuning_test.py index aff39f25b8..e59f016da4 100644 --- a/examples/run_seq2seq_finetuning_test.py +++ b/examples/run_seq2seq_finetuning_test.py @@ -14,21 +14,21 @@ # limitations under the License. import unittest -from run_seq2seq_finetuning import _fit_to_block_size +from run_seq2seq_finetuning import _fit_to_block_size, process_story class DataLoaderTest(unittest.TestCase): def setUp(self): self.block_size = 10 - def test_source_and_target_too_small(self): + def test_truncate_source_and_target_too_small(self): """ When the sum of the lengths of the source and target sequences is smaller than the block size (minus the number of special tokens), skip the example.
""" src_seq = [1, 2, 3, 4] tgt_seq = [5, 6] self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None) - def test_source_and_target_fit_exactly(self): + def test_truncate_source_and_target_fit_exactly(self): """ When the sum of the lengths of the source and target sequences is equal to the block size (minus the number of special tokens), return the sequences unchanged. """ @@ -38,27 +38,61 @@ class DataLoaderTest(unittest.TestCase): self.assertListEqual(src_seq, fitted_src) self.assertListEqual(tgt_seq, fitted_tgt) - def test_source_too_big_target_ok(self): + def test_truncate_source_too_big_target_ok(self): src_seq = [1, 2, 3, 4, 5, 6] tgt_seq = [1, 2] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) self.assertListEqual(fitted_tgt, fitted_tgt) - def test_target_too_big_source_ok(self): + def test_truncate_target_too_big_source_ok(self): src_seq = [1, 2, 3, 4] tgt_seq = [1, 2, 3, 4] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) self.assertListEqual(fitted_src, src_seq) self.assertListEqual(fitted_tgt, [1, 2, 3]) - def test_source_and_target_too_big(self): + def test_truncate_source_and_target_too_big(self): src_seq = [1, 2, 3, 4, 5, 6, 7] tgt_seq = [1, 2, 3, 4, 5, 6, 7] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) self.assertListEqual(fitted_tgt, [1, 2]) + def test_process_story_no_highlights(self): + """ Processing a story with no highlights should raise an exception. + """ + raw_story = """It was the year of Our Lord one thousand seven hundred and + seventy-five.\n\nSpiritual revelations were conceded to England at that + favoured period, as at this.""" + with self.assertRaises(IndexError): + process_story(raw_story) + + def test_process_empty_story(self): + """ An empty story should also raise an exception.
+ """ + raw_story = "" + with self.assertRaises(IndexError): + process_story(raw_story) + + def test_story_with_missing_period(self): + raw_story = ( + "It was the year of Our Lord one thousand seven hundred and " + "seventy-five\n\nSpiritual revelations were conceded to England " + "at that favoured period, as at this.\n@highlight\n\nIt was the best of times" + ) + story, summary = process_story(raw_story) + + expected_story = ( + "It was the year of Our Lord one thousand seven hundred and " + "seventy-five. Spiritual revelations were conceded to England at that " + "favoured period, as at this." + ) + self.assertEqual(expected_story, story) + + expected_summary = "It was the best of times." + self.assertEqual(expected_summary, summary) + if __name__ == "__main__": unittest.main()