From 22e1af68596690558cd8df45b6bc75e665cc1c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 15 Oct 2019 14:39:56 +0200 Subject: [PATCH] truncation function is fully tested --- examples/run_seq2seq_finetuning.py | 101 ++++++++++++++---------- examples/run_seq2seq_finetuning_test.py | 32 ++++---- 2 files changed, 74 insertions(+), 59 deletions(-) diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index 1f247ab25b..e926523a17 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -41,7 +41,7 @@ import numpy as np import torch from torch.utils.data import Dataset -from transformers import BertConfig, Bert2Rnd, BertTokenizer +from transformers import BertTokenizer logger = logging.getLogger(__name__) @@ -57,19 +57,23 @@ class TextDataset(Dataset): CNN/Daily News: - The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as - sentences that are prefixed by the special `@highlight` line. To process the - data, untar both datasets in the same folder, and pass the path to this + The CNN/Daily News raw datasets are downloaded from [1]. The stories are + stored in different files; the summary appears at the end of the story as + sentences that are prefixed by the special `@highlight` line. To process + the data, untar both datasets in the same folder, and pass the path to this folder as the "data_dir argument. The formatting code was inspired by [2]. [1] https://cs.nyu.edu/~kcho/ [2] https://github.com/abisee/cnn-dailymail/ """ - def __init_(self, tokenizer, data_dir='', block_size=512): + + def __init_(self, tokenizer, data_dir="", block_size=512): assert os.path.isdir(data_dir) # Load features that have already been computed if present - cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir)) + cached_features_file = os.path.join( + data_dir, "cached_lm_{}_{}".format(block_size, data_dir) + ) if os.path.exists(cached_features_file): logger.info("Loading features from cached file %s", cached_features_file) with open(cached_features_file, "rb") as source: @@ -78,7 +82,7 @@ class TextDataset(Dataset): logger.info("Creating features from dataset at %s", data_dir) - datasets = ['cnn', 'dailymail'] + datasets = ["cnn", "dailymail"] for dataset in datasets: path_to_stories = os.path.join(data_dir, dataset, "stories") assert os.path.isdir(path_to_stories) @@ -99,7 +103,9 @@ class TextDataset(Dataset): story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) story_seq, summary_seq = _fit_to_block_size(story, summary, block_size) - example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq) + example = tokenizer.add_special_token_sequence_pair( + story_seq, summary_seq + ) self.examples.append(example) logger.info("Saving features into cache file %s", cached_features_file) @@ -117,7 +123,9 @@ def process_story(raw_story): """ Process the text contained in a story file. Returns the story and the summary """ - file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])) + file_lines = list( + filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) + ) # for some unknown reason some lines miss a period, add it file_lines = [_add_missing_period(line) for line in file_lines] @@ -145,7 +153,7 @@ def process_story(raw_story): def _add_missing_period(line): - END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"] + END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"] if line.startswith("@highlight"): return line if line[-1] in END_TOKENS: @@ -154,34 +162,35 @@ def _add_missing_period(line): def _fit_to_block_size(src_sequence, tgt_sequence, block_size): - """ Concatenate the sequences and adapt their lengths to the block size. + """ Adapt the source and target sequences' lengths to the block size. - Following [1] we truncate the source and target + tokens sequences so they fit - in the block size. If the concatenated sequence is longer than 512 we follow - the 75%/25% rule in [1]: limit the source sequence's length to 384 and the - target sequence's length to 128. + If the concatenated sequence (source + target + 3 special tokens) would be + longer than the block size we use the 75% / 25% rule followed in [1]. For a + block size of 512 this means limiting the source sequence's length to 384 + and the target sequence's length to 128. [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019). """ SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token - TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token + TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1 # EOS token - # we dump the examples that are too small to fit in the block size for the + # We dump the examples that are too small to fit in the block size for the # sake of simplicity. You can modify this by adding model-specific padding. - if len(src_sequence) + len(src_sequence) + 3 < block_size: + if len(src_sequence) + len(tgt_sequence) + 3 < block_size: return None - # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now. if len(src_sequence) > SRC_MAX_LENGTH: if len(tgt_sequence) > TGT_MAX_LENGTH: src_sequence = src_sequence[:SRC_MAX_LENGTH] tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH] else: - src_sequence = src_sequence[block_size - len(tgt_sequence) - 3] + remain_size = block_size - len(tgt_sequence) - 3 + src_sequence = src_sequence[:remain_size] else: if len(tgt_sequence) > TGT_MAX_LENGTH: - tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3] + remain_size = block_size - len(src_sequence) - 3 + tgt_sequence = tgt_sequence[:remain_size] return src_sequence, tgt_sequence @@ -200,44 +209,50 @@ def main(): parser = argparse.ArgumentParser() # Required parameters - parser.add_argument("--data_dir", - default=None, - type=str, - required=True, - help="The input training data file (a text file).") - parser.add_argument("--output_dir", - default=None, - type=str, - required=True, - help="The output directory where the model predictions and checkpoints will be written.") + parser.add_argument( + "--data_dir", + default=None, + type=str, + required=True, + help="The input training data file (a text file).", + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) # Optional parameters - parser.add_argument("--model_name_or_path", - default="bert-base-cased", - type=str, - help="The model checkpoint for weights initialization.") + parser.add_argument( + "--model_name_or_path", + default="bert-base-cased", + type=str, + help="The model checkpoint for weights initialization.", + ) parser.add_argument("--seed", default=42, type=int) args = parser.parse_args() # Set up training device - device = torch.device("cpu") + # device = torch.device("cpu") # Set seed set_seed(args) # Load pretrained model and tokenizer - config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer - config = config_class.from_pretrained(args.model_name_or_path) + tokenizer_class = BertTokenizer + # config = config_class.from_pretrained(args.model_name_or_path) tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) - model = model_class.from_pretrained(args.model_name_or_path, config=config) - model.to(device) + # model = model_class.from_pretrained(args.model_name_or_path, config=config) + # model.to(device) logger.info("Training/evaluation parameters %s", args) # Training - train_dataset = load_and_cache_examples(args, tokenizer) - global_step, tr_loss = train(args, train_dataset, model, tokenizer) - logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + _ = load_and_cache_examples(args, tokenizer) + # global_step, tr_loss = train(args, train_dataset, model, tokenizer) + # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) if __name__ == "__main__": diff --git a/examples/run_seq2seq_finetuning_test.py b/examples/run_seq2seq_finetuning_test.py index 34d9add10d..aff39f25b8 100644 --- a/examples/run_seq2seq_finetuning_test.py +++ b/examples/run_seq2seq_finetuning_test.py @@ -14,50 +14,50 @@ # limitations under the License. import unittest -from .run_seq2seq_finetuning import process_story, _fit_to_block_size +from run_seq2seq_finetuning import _fit_to_block_size class DataLoaderTest(unittest.TestCase): - def __init__(self, block_size=10): - self.block_size = block_size + def setUp(self): + self.block_size = 10 - def source_and_target_too_small(self): + def test_source_and_target_too_small(self): """ When the sum of the lengths of the source and target sequences is smaller than the block size (minus the number of special tokens), skip the example. """ src_seq = [1, 2, 3, 4] tgt_seq = [5, 6] self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None) - def source_and_target_fit_exactly(self): + def test_source_and_target_fit_exactly(self): """ When the sum of the lengths of the source and target sequences is equal to the block size (minus the number of special tokens), return the sequences unchanged. """ src_seq = [1, 2, 3, 4] tgt_seq = [5, 6, 7] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(src_seq == fitted_src) - self.assertListEqual(tgt_seq == fitted_tgt) + self.assertListEqual(src_seq, fitted_src) + self.assertListEqual(tgt_seq, fitted_tgt) - def source_too_big_target_ok(self): + def test_source_too_big_target_ok(self): src_seq = [1, 2, 3, 4, 5, 6] tgt_seq = [1, 2] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(src_seq == [1, 2, 3, 4, 5]) - self.assertListEqual(tgt_seq == fitted_tgt) + self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) + self.assertListEqual(fitted_tgt, fitted_tgt) - def target_too_big_source_ok(self): + def test_target_too_big_source_ok(self): src_seq = [1, 2, 3, 4] tgt_seq = [1, 2, 3, 4] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(src_seq == src_seq) - self.assertListEqual(tgt_seq == [1, 2, 3]) + self.assertListEqual(fitted_src, src_seq) + self.assertListEqual(fitted_tgt, [1, 2, 3]) - def source_and_target_too_big(self): + def test_source_and_target_too_big(self): src_seq = [1, 2, 3, 4, 5, 6, 7] tgt_seq = [1, 2, 3, 4, 5, 6, 7] fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) - self.assertListEqual(src_seq == [1, 2, 3, 4, 5]) - self.assertListEqual(tgt_seq == [1, 2]) + self.assertListEqual(fitted_src, [1, 2, 3, 4, 5]) + self.assertListEqual(fitted_tgt, [1, 2]) if __name__ == "__main__":