wip commit, switching computers
This commit is contained in:
@@ -31,7 +31,7 @@ Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import dequeue
|
from collections import deque
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
@@ -57,9 +57,9 @@ class TextDataset(Dataset):
|
|||||||
|
|
||||||
CNN/Daily News:
|
CNN/Daily News:
|
||||||
|
|
||||||
The CNN/Daily News raw datasets are downloaded from [1]. They consist in stories stored
|
The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as
|
||||||
in different files where the summary sentences are indicated by the special `@highlight` token.
|
sentences that are prefixed by the special `@highlight` line. To process the
|
||||||
To process the data, untar both datasets in the same folder, and pass the path to this
|
data, untar both datasets in the same folder, and pass the path to this
|
||||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||||
|
|
||||||
[1] https://cs.nyu.edu/~kcho/
|
[1] https://cs.nyu.edu/~kcho/
|
||||||
@@ -69,7 +69,7 @@ class TextDataset(Dataset):
|
|||||||
assert os.path.isdir(data_dir)
|
assert os.path.isdir(data_dir)
|
||||||
|
|
||||||
# Load features that have already been computed if present
|
# Load features that have already been computed if present
|
||||||
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir)
|
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
|
||||||
if os.path.exists(cached_features_file):
|
if os.path.exists(cached_features_file):
|
||||||
logger.info("Loading features from cached file %s", cached_features_file)
|
logger.info("Loading features from cached file %s", cached_features_file)
|
||||||
with open(cached_features_file, "rb") as source:
|
with open(cached_features_file, "rb") as source:
|
||||||
@@ -86,18 +86,19 @@ class TextDataset(Dataset):
|
|||||||
stories_files = os.listdir(path_to_stories)
|
stories_files = os.listdir(path_to_stories)
|
||||||
for story_file in stories_files:
|
for story_file in stories_files:
|
||||||
path_to_story = os.path.join(path_to_stories, "story_file")
|
path_to_story = os.path.join(path_to_stories, "story_file")
|
||||||
if !os.path.isfile(path_to_story):
|
if not os.path.isfile(path_to_story):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with open(path_to_story, encoding="utf-8") as source:
|
with open(path_to_story, encoding="utf-8") as source:
|
||||||
try:
|
try:
|
||||||
story, summary = process_story(source)
|
raw_story = source.read()
|
||||||
|
story, summary = process_story(raw_story)
|
||||||
except IndexError:
|
except IndexError:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||||
story_seq, summary_seq = _fit_to_block_size(story, summary, blocksize)
|
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
|
||||||
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||||
self.examples.append(example)
|
self.examples.append(example)
|
||||||
|
|
||||||
@@ -108,22 +109,22 @@ class TextDataset(Dataset):
|
|||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.examples)
|
return len(self.examples)
|
||||||
|
|
||||||
def __getitem__(self):
|
def __getitem__(self, items):
|
||||||
return torch.tensor(self.examples[items])
|
return torch.tensor(self.examples[items])
|
||||||
|
|
||||||
|
|
||||||
def process_story(story_file):
|
def process_story(raw_story):
|
||||||
""" Process the text contained in a story file.
|
""" Process the text contained in a story file.
|
||||||
Returns the story and the summary
|
Returns the story and the summary
|
||||||
"""
|
"""
|
||||||
file_lines = list(filter(lambda x: len(x)!=0, [line.strip() for lines in story_file]))
|
file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||||
|
|
||||||
# for some unknown reason some lines miss a period, add it
|
# for some unknown reason some lines miss a period, add it
|
||||||
file_lines = [_add_missing_period(line) for line in file_lines]
|
file_lines = [_add_missing_period(line) for line in file_lines]
|
||||||
|
|
||||||
# gather article lines
|
# gather article lines
|
||||||
story_lines = []
|
story_lines = []
|
||||||
lines = dequeue(file_lines)
|
lines = deque(file_lines)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
element = lines.popleft()
|
element = lines.popleft()
|
||||||
@@ -134,7 +135,7 @@ def process_story(story_file):
|
|||||||
raise ie
|
raise ie
|
||||||
|
|
||||||
# gather summary lines
|
# gather summary lines
|
||||||
highlights_lines = list(filter(lambda t: !t.startswith("@highlight"), lines))
|
highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
|
||||||
|
|
||||||
# join the lines
|
# join the lines
|
||||||
story = " ".join(story_lines)
|
story = " ".join(story_lines)
|
||||||
@@ -145,7 +146,7 @@ def process_story(story_file):
|
|||||||
|
|
||||||
def _add_missing_period(line):
|
def _add_missing_period(line):
|
||||||
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
|
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
|
||||||
if line == "@highlight":
|
if line.startswith("@highlight"):
|
||||||
return line
|
return line
|
||||||
if line[-1] in END_TOKENS:
|
if line[-1] in END_TOKENS:
|
||||||
return line
|
return line
|
||||||
@@ -163,8 +164,8 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
|||||||
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
||||||
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
||||||
"""
|
"""
|
||||||
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
|
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
|
||||||
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
|
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
|
||||||
|
|
||||||
# we dump the examples that are too small to fit in the block size for the
|
# we dump the examples that are too small to fit in the block size for the
|
||||||
# sake of simplicity. You can modify this by adding model-specific padding.
|
# sake of simplicity. You can modify this by adding model-specific padding.
|
||||||
@@ -172,22 +173,21 @@ def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
|
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
|
||||||
if len(src_sequence) > SRC_MAX_LENGTH
|
if len(src_sequence) > SRC_MAX_LENGTH:
|
||||||
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
||||||
src_sequence = src_sequence[:SRC_MAX_LENGTH]
|
src_sequence = src_sequence[:SRC_MAX_LENGTH]
|
||||||
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
|
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
|
||||||
else:
|
else:
|
||||||
src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
|
src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
|
||||||
else:
|
else:
|
||||||
if len(tgt_tokens) > TGT_MAX_LENGTH:
|
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
||||||
tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
|
tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
|
||||||
|
|
||||||
return src_sequence, tgt_sequence
|
return src_sequence, tgt_sequence
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_cache_examples(args, tokenizer):
|
def load_and_cache_examples(args, tokenizer):
|
||||||
dataset = TextDataset(tokenizer, file_path=args.train_data_file)
|
dataset = TextDataset(tokenizer, file_path=args.data_dir)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
@@ -200,7 +200,7 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
# Required parameters
|
# Required parameters
|
||||||
parser.add_argument("--train_data_file",
|
parser.add_argument("--data_dir",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
|
|||||||
64
examples/run_seq2seq_finetuning_test.py
Normal file
64
examples/run_seq2seq_finetuning_test.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from .run_seq2seq_finetuning import process_story, _fit_to_block_size
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoaderTest(unittest.TestCase):
|
||||||
|
def __init__(self, block_size=10):
|
||||||
|
self.block_size = block_size
|
||||||
|
|
||||||
|
def source_and_target_too_small(self):
|
||||||
|
""" When the sum of the lengths of the source and target sequences is
|
||||||
|
smaller than the block size (minus the number of special tokens), skip the example. """
|
||||||
|
src_seq = [1, 2, 3, 4]
|
||||||
|
tgt_seq = [5, 6]
|
||||||
|
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
|
||||||
|
|
||||||
|
def source_and_target_fit_exactly(self):
|
||||||
|
""" When the sum of the lengths of the source and target sequences is
|
||||||
|
equal to the block size (minus the number of special tokens), return the
|
||||||
|
sequences unchanged. """
|
||||||
|
src_seq = [1, 2, 3, 4]
|
||||||
|
tgt_seq = [5, 6, 7]
|
||||||
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
|
self.assertListEqual(src_seq == fitted_src)
|
||||||
|
self.assertListEqual(tgt_seq == fitted_tgt)
|
||||||
|
|
||||||
|
def source_too_big_target_ok(self):
|
||||||
|
src_seq = [1, 2, 3, 4, 5, 6]
|
||||||
|
tgt_seq = [1, 2]
|
||||||
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
|
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
|
||||||
|
self.assertListEqual(tgt_seq == fitted_tgt)
|
||||||
|
|
||||||
|
def target_too_big_source_ok(self):
|
||||||
|
src_seq = [1, 2, 3, 4]
|
||||||
|
tgt_seq = [1, 2, 3, 4]
|
||||||
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
|
self.assertListEqual(src_seq == src_seq)
|
||||||
|
self.assertListEqual(tgt_seq == [1, 2, 3])
|
||||||
|
|
||||||
|
def source_and_target_too_big(self):
|
||||||
|
src_seq = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||||
|
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
|
||||||
|
self.assertListEqual(tgt_seq == [1, 2])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user