test the full story processing
This commit is contained in:
@@ -14,21 +14,21 @@
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from run_seq2seq_finetuning import _fit_to_block_size
|
||||
from run_seq2seq_finetuning import _fit_to_block_size, process_story
|
||||
|
||||
|
||||
class DataLoaderTest(unittest.TestCase):
|
||||
def setUp(self):
    """Fix the block size handed to `_fit_to_block_size` by the tests."""
    self.block_size = 10
|
||||
|
||||
def test_source_and_target_too_small(self):
|
||||
def test_truncate_source_and_target_too_small(self):
    """ When the sum of the lengths of the source and target sequences is
    smaller than the block size (minus the number of special tokens),
    `_fit_to_block_size` returns None so that the example is skipped. """
    src_seq = [1, 2, 3, 4]
    tgt_seq = [5, 6]
    # Identity check: the contract is "returns None", not "returns
    # something that compares equal to None".
    self.assertIsNone(_fit_to_block_size(src_seq, tgt_seq, self.block_size))
|
||||
|
||||
def test_source_and_target_fit_exactly(self):
|
||||
def test_truncate_source_and_target_fit_exactly(self):
|
||||
""" When the sum of the lengths of the source and target sequences is
|
||||
equal to the block size (minus the number of special tokens), return the
|
||||
sequences unchanged. """
|
||||
@@ -38,27 +38,61 @@ class DataLoaderTest(unittest.TestCase):
|
||||
self.assertListEqual(src_seq, fitted_src)
|
||||
self.assertListEqual(tgt_seq, fitted_tgt)
|
||||
|
||||
def test_source_too_big_target_ok(self):
|
||||
def test_truncate_source_too_big_target_ok(self):
    """ When only the source is too long for the block, the source is
    truncated and the target is returned unchanged. """
    src_seq = [1, 2, 3, 4, 5, 6]
    tgt_seq = [1, 2]
    fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
    self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
    # Compare against the input target: comparing `fitted_tgt` with itself
    # can never fail and would hide a wrongly truncated target.
    self.assertListEqual(fitted_tgt, tgt_seq)
|
||||
|
||||
def test_target_too_big_source_ok(self):
|
||||
def test_truncate_target_too_big_source_ok(self):
    """ With a 4-token source and a 4-token target, the source comes back
    unchanged and the target is cut down to its first three tokens. """
    source, target = [1, 2, 3, 4], [1, 2, 3, 4]
    fitted = _fit_to_block_size(source, target, self.block_size)
    fitted_source, fitted_target = fitted
    self.assertListEqual(fitted_source, source)
    self.assertListEqual(fitted_target, [1, 2, 3])
|
||||
|
||||
def test_source_and_target_too_big(self):
|
||||
def test_truncate_source_and_target_too_big(self):
    """ With 7 tokens on each side, the source is cut to its first five
    tokens and the target to its first two. """
    source, target = [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7]
    fitted = _fit_to_block_size(source, target, self.block_size)
    fitted_source, fitted_target = fitted
    self.assertListEqual(fitted_source, [1, 2, 3, 4, 5])
    self.assertListEqual(fitted_target, [1, 2])
|
||||
|
||||
def test_process_story_no_highlights(self):
    """ A story that contains no `@highlight` section must make
    `process_story` raise an IndexError. """
    story_without_highlights = (
        "It was the year of Our Lord one thousand seven hundred and\n"
        "seventy-five.\n\nSpiritual revelations were conceded to England at that\n"
        "favoured period, as at this."
    )
    with self.assertRaises(IndexError):
        process_story(story_without_highlights)
|
||||
|
||||
def test_process_empty_story(self):
    """ An empty story must also make `process_story` raise an
    IndexError. """
    empty_story = ""
    # Callable form of assertRaises: invokes process_story(empty_story)
    # and checks that it raises IndexError.
    self.assertRaises(IndexError, process_story, empty_story)
|
||||
|
||||
def test_story_with_missing_period(self):
    """ Lines that lack a final period get one added: the story paragraphs
    come back joined by a single space, and the highlight comes back as
    the summary with a trailing period. """
    expected_story = (
        "It was the year of Our Lord one thousand seven hundred and "
        "seventy-five. Spiritual revelations were conceded to England at that "
        "favoured period, as at this."
    )
    expected_summary = "It was the best of times."

    # First paragraph and the highlight both end without a period.
    raw_story = (
        "It was the year of Our Lord one thousand seven hundred and "
        "seventy-five\n\nSpiritual revelations were conceded to England "
        "at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
    )
    processed = process_story(raw_story)
    story, summary = processed

    self.assertEqual(expected_story, story)
    self.assertEqual(expected_summary, summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the test suite when this module is executed directly as a script.
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user