From 260ac7d9a8501f6c631adc355e269e7f3f6274f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 15 Oct 2019 12:24:35 +0200 Subject: [PATCH] wip commit, switching computers --- examples/run_seq2seq_finetuning.py | 42 ++++++++-------- examples/run_seq2seq_finetuning_test.py | 64 +++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 21 deletions(-) create mode 100644 examples/run_seq2seq_finetuning_test.py diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py index 5d7da58a23..1f247ab25b 100644 --- a/examples/run_seq2seq_finetuning.py +++ b/examples/run_seq2seq_finetuning.py @@ -31,7 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197 """ import argparse -import dequeue +from collections import deque import logging import pickle import random @@ -57,9 +57,9 @@ class TextDataset(Dataset): CNN/Daily News: - The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored - in different files where the summary sentences are indicated by the special `@highlight` token. - To process the data, untar both datasets in the same folder, and pass the path to this + The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as + sentences that are prefixed by the special `@highlight` line. To process the + data, untar both datasets in the same folder, and pass the path to this folder as the "data_dir argument. The formatting code was inspired by [2]. [1] https://cs.nyu.edu/~kcho/ @@ -69,7 +69,7 @@ class TextDataset(Dataset): assert os.path.isdir(data_dir) # Load features that have already been computed if present - cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir) + cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir)) if os.path.exists(cached_features_file): logger.info("Loading features from cached file %s", cached_features_file) with open(cached_features_file, "rb") as source: @@ -86,18 +86,19 @@ class TextDataset(Dataset): stories_files = os.listdir(path_to_stories) for story_file in stories_files: path_to_story = os.path.join(path_to_stories, "story_file") - if !os.path.isfile(path_to_story): + if not os.path.isfile(path_to_story): continue with open(path_to_story, encoding="utf-8") as source: try: - story, summary = process_story(source) + raw_story = source.read() + story, summary = process_story(raw_story) except IndexError: continue story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story)) summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary)) - story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize) + story_seq, summary_seq = _fit_to_block_size(story, summary, block_size) example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq) self.examples.append(example) @@ -108,22 +109,22 @@ class TextDataset(Dataset): def __len__(self): return len(self.examples) - def __getitem__(self): + def __getitem__(self, items): return torch.tensor(self.examples[items]) -def process_story(story_file): +def process_story(raw_story): """ Process the text contained in a story file. Returns the story and the summary """ - file_lines = list(filter(lambda x: len(x)!=0, [line.strip() for lines in story_file])) + file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])) # for some unknown reason some lines miss a period, add it file_lines = [_add_missing_period(line) for line in file_lines] # gather article lines story_lines = [] - lines = dequeue(file_lines) + lines = deque(file_lines) while True: try: element = lines.popleft() @@ -134,7 +135,7 @@ def process_story(story_file): raise ie # gather summary lines - highlights_lines = list(filter(lambda t: !t.startswith("@highlight"), lines)) + highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines)) # join the lines story = " ".join(story_lines) @@ -145,7 +146,7 @@ def process_story(story_file): def _add_missing_period(line): END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"] - if line == "@highlight": + if line.startswith("@highlight"): return line if line[-1] in END_TOKENS: return line @@ -163,8 +164,8 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size): [1] Dong, Li, et al. "Unified Language Model Pre-training for Natural Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019). """ - SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token - TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token + SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token + TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token # we dump the examples that are too small to fit in the block size for the # sake of simplicity. You can modify this by adding model-specific padding. @@ -172,22 +173,21 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size): return None # the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now. - if len(src_sequence) > SRC_MAX_LENGTH + if len(src_sequence) > SRC_MAX_LENGTH: if len(tgt_sequence) > TGT_MAX_LENGTH: src_sequence = src_sequence[:SRC_MAX_LENGTH] tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH] else: src_sequence = src_sequence[block_size - len(tgt_sequence) - 3] else: - if len(tgt_tokens) > TGT_MAX_LENGTH: + if len(tgt_sequence) > TGT_MAX_LENGTH: tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3] return src_sequence, tgt_sequence - def load_and_cache_examples(args, tokenizer): - dataset = TextDataset(tokenizer, file_path=args.train_data_file) + dataset = TextDataset(tokenizer, file_path=args.data_dir) return dataset @@ -200,7 +200,7 @@ def main(): parser = argparse.ArgumentParser() # Required parameters - parser.add_argument("--train_data_file", + parser.add_argument("--data_dir", default=None, type=str, required=True, diff --git a/examples/run_seq2seq_finetuning_test.py b/examples/run_seq2seq_finetuning_test.py new file mode 100644 index 0000000000..34d9add10d --- /dev/null +++ b/examples/run_seq2seq_finetuning_test.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2019 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from .run_seq2seq_finetuning import process_story, _fit_to_block_size + + +class DataLoaderTest(unittest.TestCase): + def __init__(self, block_size=10): + self.block_size = block_size + + def source_and_target_too_small(self): + """ When the sum of the lengths of the source and target sequences is + smaller than the block size (minus the number of special tokens), skip the example. """ + src_seq = [1, 2, 3, 4] + tgt_seq = [5, 6] + self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None) + + def source_and_target_fit_exactly(self): + """ When the sum of the lengths of the source and target sequences is + equal to the block size (minus the number of special tokens), return the + sequences unchanged. """ + src_seq = [1, 2, 3, 4] + tgt_seq = [5, 6, 7] + fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) + self.assertListEqual(src_seq == fitted_src) + self.assertListEqual(tgt_seq == fitted_tgt) + + def source_too_big_target_ok(self): + src_seq = [1, 2, 3, 4, 5, 6] + tgt_seq = [1, 2] + fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) + self.assertListEqual(src_seq == [1, 2, 3, 4, 5]) + self.assertListEqual(tgt_seq == fitted_tgt) + + def target_too_big_source_ok(self): + src_seq = [1, 2, 3, 4] + tgt_seq = [1, 2, 3, 4] + fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) + self.assertListEqual(src_seq == src_seq) + self.assertListEqual(tgt_seq == [1, 2, 3]) + + def source_and_target_too_big(self): + src_seq = [1, 2, 3, 4, 5, 6, 7] + tgt_seq = [1, 2, 3, 4, 5, 6, 7] + fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size) + self.assertListEqual(src_seq == [1, 2, 3, 4, 5]) + self.assertListEqual(tgt_seq == [1, 2]) + + +if __name__ == "__main__": + unittest.main()