diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index 4a7042929e..5d7da58a23 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -53,20 +53,14 @@ def set_seed(args): class TextDataset(Dataset): - """ Abstracts a dataset used to train seq2seq models. - - A seq2seq dataset consists of two files: - - The source file that contains the source sequences, one line per sequence; - - The target file contains the target sequences, one line per sequence. - - The matching betwen source and target sequences is made on the basis of line numbers. + """ Abstracts the dataset used to train seq2seq models. CNN/Daily News: The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored in different files where the summary sentences are indicated by the special `@highlight` token. - To process the data, untar both datasets in the same folder, and path the path to this - folder as the "train_data_file" argument. The formatting code was inspired by [2]. + To process the data, untar both datasets in the same folder, and pass the path to this + folder as the "data_dir" argument. The formatting code was inspired by [2]. 
[1] https://cs.nyu.edu/~kcho/ [2] https://github.com/abisee/cnn-dailymail/ @@ -82,9 +76,8 @@ class TextDataset(Dataset): self.examples = pickle.load(source) return - logger.info("Creating features from dataset at %s", directory) + logger.info("Creating features from dataset at %s", data_dir) - # we need to iterate over both the cnn and the dailymail dataset datasets = ['cnn', 'dailymail'] for dataset in datasets: path_to_stories = os.path.join(data_dir, dataset, "stories") @@ -102,9 +95,10 @@ class TextDataset(Dataset): except IndexError: continue - src_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) - tgt_sequence = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) - example = _truncate_and_concatenate(src_sequence, tgt_sequence, blocksize) + story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) + summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) + story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize) + example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq) self.examples.append(example) logger.info("Saving features into cache file %s", cached_features_file) @@ -158,15 +152,13 @@ def _add_missing_period(line): return line + " ." -def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size): +def _fit_to_block_size(src_sequence, tgt_sequence, block_size): """ Concatenate the sequences and adapt their lengths to the block size. - Following [1] we perform the following transformations: - - Add an [CLS] token at the beginning of the source sequence; - - Add an [EOS] token at the end of the source and target sequences; - - Concatenate the source and target + tokens sequence. If the concatenated sequence is - longer than 512 we follow the 75%/25% rule in [1]: limit the source sequence's length to 384 - and the target sequence's length to 128. + Following [1] we truncate the source and target + tokens sequences so they fit + in the block size. 
If the concatenated sequence is longer than 512 we follow + the 75%/25% rule in [1]: limit the source sequence's length to 384 and the + target sequence's length to 128. [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019). @@ -176,22 +168,21 @@ def _truncate_and_concatenate(src_sequence, tgt_sequence, block_size): # we dump the examples that are too small to fit in the block size for the # sake of simplicity. You can modify this by adding model-specific padding. - if len(src_tokens) + len(src_tokens) + 3 < block_size: + if len(src_sequence) + len(src_sequence) + 3 < block_size: return None # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now. - if len(src_tokens) > SRC_MAX_LENGTH - if len(tgt_tokens) > TGT_MAX_LENGTH: - src_tokens = src_tokens[:SRC_MAX_LENGTH] - tgt_tokens = tgt_tokens[:TGT_MAX_LENGTH] + if len(src_sequence) > SRC_MAX_LENGTH + if len(tgt_sequence) > TGT_MAX_LENGTH: + src_sequence = src_sequence[:SRC_MAX_LENGTH] + tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH] else: - src_tokens = src_tokens[block_size - len(tgt_tokens) - 3] + src_sequence = src_sequence[block_size - len(tgt_sequence) - 3] else: if len(tgt_tokens) > TGT_MAX_LENGTH: - tgt_tokens = tgt_tokens[block_size - len(src_tokens) - 3] + tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3] - # I add the special tokens manually, but this should be done by the tokenizer. That's the next step. - return ["[CLS]"] + src_tokens + ["[EOS]"] + tgt_tokens + ["[EOS]"] + return src_sequence, tgt_sequence @@ -250,4 +241,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()