truncation function is fully tested

This commit is contained in:
Rémi Louf
2019-10-15 14:39:56 +02:00
parent 260ac7d9a8
commit 22e1af6859
2 changed files with 74 additions and 59 deletions

View File

@@ -14,50 +14,50 @@
# limitations under the License.
import unittest
from .run_seq2seq_finetuning import process_story, _fit_to_block_size
from run_seq2seq_finetuning import _fit_to_block_size
class DataLoaderTest(unittest.TestCase):
def __init__(self, block_size=10):
self.block_size = block_size
def setUp(self):
self.block_size = 10
def source_and_target_too_small(self):
def test_source_and_target_too_small(self):
""" When the sum of the lengths of the source and target sequences is
smaller than the block size (minus the number of special tokens), skip the example. """
src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6]
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
def source_and_target_fit_exactly(self):
def test_source_and_target_fit_exactly(self):
""" When the sum of the lengths of the source and target sequences is
equal to the block size (minus the number of special tokens), return the
sequences unchanged. """
src_seq = [1, 2, 3, 4]
tgt_seq = [5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == fitted_src)
self.assertListEqual(tgt_seq == fitted_tgt)
self.assertListEqual(src_seq, fitted_src)
self.assertListEqual(tgt_seq, fitted_tgt)
def source_too_big_target_ok(self):
def test_source_too_big_target_ok(self):
src_seq = [1, 2, 3, 4, 5, 6]
tgt_seq = [1, 2]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
self.assertListEqual(tgt_seq == fitted_tgt)
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, fitted_tgt)
def target_too_big_source_ok(self):
def test_target_too_big_source_ok(self):
src_seq = [1, 2, 3, 4]
tgt_seq = [1, 2, 3, 4]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == src_seq)
self.assertListEqual(tgt_seq == [1, 2, 3])
self.assertListEqual(fitted_src, src_seq)
self.assertListEqual(fitted_tgt, [1, 2, 3])
def source_and_target_too_big(self):
def test_source_and_target_too_big(self):
src_seq = [1, 2, 3, 4, 5, 6, 7]
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
self.assertListEqual(tgt_seq == [1, 2])
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
self.assertListEqual(fitted_tgt, [1, 2])
if __name__ == "__main__":