diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py
index 7941384506..4a7042929e 100644
--- a/examples/run_seq2seq_finetuning.py
+++ b/examples/run_seq2seq_finetuning.py
@@ -17,9 +17,9 @@
 
 We use the procedure described in [1] to finetune models for sequence
 generation. Let S1 and S2 be the source and target sequence respectively; we
-pack them using the start of sequence [SOS] and end of sequence [EOS] token:
+pack them using the start of sequence [CLS] and end of sequence [EOS] token:
 
-    [SOS] S1 [EOS] S2 [EOS]
+    [CLS] S1 [EOS] S2 [EOS]
 
 We then mask a fixed percentage of token from S2 at random and learn to predict
 the masked words. [EOS] can be masked during finetuning so the model learns to
@@ -31,6 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
 """
 
 import argparse
+from collections import deque
 import logging
 import pickle
 import random
@@ -54,7 +55,7 @@ def set_seed(args):
 class TextDataset(Dataset):
     """ Abstracts a dataset used to train seq2seq models.
 
-    A seq2seq dataset consists in two files:
+    A seq2seq dataset consists of two files:
     - The source file that contains the source sequences, one line per sequence;
     - The target file contains the target sequences, one line per sequence.
 
@@ -62,43 +63,54 @@ class TextDataset(Dataset):
 
     CNN/Daily News:
 
-    The CNN/Daily News dataset downloaded from [1] consists of two files that
-    respectively contain the stories and the associated summaries. Each line
-    corresponds to a different story. The files contain WordPiece tokens.
+    The CNN/Daily News raw datasets are downloaded from [1]. They consist of stories stored
+    in different files where the summary sentences are indicated by the special `@highlight` token.
+    To process the data, untar both datasets in the same folder, and pass the path to this
+    folder as the "train_data_file" argument. The formatting code was inspired by [2].
 
-    train.src: the longest story contains 6966 tokens, the shortest 12.
-    Sentences are separated with `[SEP_i]` where i is an int between 0 and 9.
-
-    train.tgt: the longest summary contains 2467 tokens, the shortest 4.
-    Sentences are separated with `[X_SEP]` tokens.
-
-    [1] https://github.com/microsoft/unilm
+    [1] https://cs.nyu.edu/~kcho/
+    [2] https://github.com/abisee/cnn-dailymail/
     """
-    def __init_(self, tokenizer, src_path='train.src', target_path='target.src' block_size=512):
-        assert os.path.isfile(file_path)
-        directory, filename = os.path.split(file_path)
+    def __init__(self, tokenizer, data_dir='', block_size=512):
+        assert os.path.isdir(data_dir)
 
-        cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, file_name)
+        # Load features that have already been computed if present
+        cached_features_file = os.path.join(data_dir, "cached_lm_{}".format(block_size))
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, "rb") as source:
                 self.examples = pickle.load(source)
-        else:
-            logger.info("Creating features from dataset at %s", directory)
+            return
 
-            self.examples = []
-            with open(src_path, encoding="utf-8") as source, open(target_path, encoding="utf-8") as target:
-                for line_src, line_tgt in zip(source, target)
-                    src_sequence = line_src.read()
-                    tgt_sequence = line_tgt.read()
-                    example = _truncate_and_concatenate(src_sequence, tgt_sequence, block_size)
-                    if example is not None:
-                        example = tokenizer.convert_tokens_to_ids(example)
-                        self.examples.append(example)
+        logger.info("Creating features from dataset at %s", data_dir)
+        self.examples = []
 
-            logger.info("Saving features into cache file %s", cached_features_file)
-            with open(cached_features_file, "wb") as sink:
-                pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
+        # we need to iterate over both the cnn and the dailymail dataset
+        datasets = ['cnn', 'dailymail']
+        for dataset in datasets:
+            path_to_stories = os.path.join(data_dir, dataset, "stories")
+            assert os.path.isdir(path_to_stories)
+
+            stories_files = os.listdir(path_to_stories)
+            for story_file in stories_files:
+                path_to_story = os.path.join(path_to_stories, story_file)
+                if not os.path.isfile(path_to_story):
+                    continue
+
+                with open(path_to_story, encoding="utf-8") as source:
+                    try:
+                        story, summary = process_story(source)
+                    except IndexError:
+                        continue
+
+                src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
+                tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
+                example = _truncate_and_concatenate(src_sequence, tgt_sequence, block_size)
+                self.examples.append(example)
+
+        logger.info("Saving features into cache file %s", cached_features_file)
+        with open(cached_features_file, "wb") as sink:
+            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
 
     def __len__(self):
         return len(self.examples)
@@ -107,6 +119,47 @@ class TextDataset(Dataset):
         return torch.tensor(self.examples[items])
 
 
+def process_story(story_file):
+    """ Process the text contained in a story file.
+    Returns the story and the summary
+    """
+    file_lines = list(filter(lambda x: len(x)!=0, [line.strip() for line in story_file]))
+
+    # for some unknown reason some lines miss a period, add it
+    file_lines = [_add_missing_period(line) for line in file_lines]
+
+    # gather article lines
+    story_lines = []
+    lines = deque(file_lines)
+    while True:
+        try:
+            element = lines.popleft()
+            if element.startswith("@highlight"):
+                break
+            story_lines.append(element)
+        except IndexError as ie:  # if "@highlight" absent from file
+            raise ie
+
+    # gather summary lines
+    highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
+
+    # join the lines
+    story = " ".join(story_lines)
+    summary = " ".join(highlights_lines)
+
+    return story, summary
+
+
+def _add_missing_period(line):
+    """ Append " ." to a line that does not already end with an end token. """
+    END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u201d', ")"]
+    if line == "@highlight":
+        return line
+    if line[-1] in END_TOKENS:
+        return line
+    return line + " ."
+
+
 def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
     """ Concatenate the sequences and adapt their lengths to the block size.
 
@@ -123,12 +176,6 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
     SRC_MAX_LENGTH = int(0.75 * block_size) - 2  # CLS and EOS token
     TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1  # EOS token
 
-    # the dataset contains special separator tokens that we remove for now.
-    # They are of the form `[SEP_i]` in the source file, and `[X_SEP]` in the
-    # target file.
-    src_tokens = list(filter(lambda t: "[SEP_" in t, src_sequence.split(" ")))
-    tgt_tokens = list(filter(lambda t: "_SEP]" in t, tgt_sequence.split(" ")))
-
     # we dump the examples that are too small to fit in the block size for the
     # sake of simplicity. You can modify this by adding model-specific padding.
     if len(src_tokens) + len(src_tokens) + 3 < block_size:
@@ -145,6 +192,7 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size):
     if len(tgt_tokens) > TGT_MAX_LENGTH:
         tgt_tokens = tgt_tokens[block_size - len(src_tokens) - 3]
 
+    # I add the special tokens manually, but this should be done by the tokenizer. That's the next step.
     return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"]
 
 