truncation function is fully tested
This commit is contained in:
@@ -41,7 +41,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from transformers import BertConfig, Bert2Rnd, BertTokenizer
|
||||
from transformers import BertTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -57,19 +57,23 @@ class TextDataset(Dataset):
|
||||
|
||||
CNN/Daily News:
|
||||
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process the
|
||||
data, untar both datasets in the same folder, and pass the path to this
|
||||
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
|
||||
stored in different files; the summary appears at the end of the story as
|
||||
sentences that are prefixed by the special `@highlight` line. To process
|
||||
the data, untar both datasets in the same folder, and pass the path to this
|
||||
folder as the "data_dir argument. The formatting code was inspired by [2].
|
||||
|
||||
[1] https://cs.nyu.edu/~kcho/
|
||||
[2] https://github.com/abisee/cnn-dailymail/
|
||||
"""
|
||||
def __init_(self, tokenizer, data_dir='', block_size=512):
|
||||
|
||||
def __init_(self, tokenizer, data_dir="", block_size=512):
|
||||
assert os.path.isdir(data_dir)
|
||||
|
||||
# Load features that have already been computed if present
|
||||
cached_features_file = os.path.join(directory, "cached_lm_{}_{}".format(block_size, data_dir))
|
||||
cached_features_file = os.path.join(
|
||||
data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
|
||||
)
|
||||
if os.path.exists(cached_features_file):
|
||||
logger.info("Loading features from cached file %s", cached_features_file)
|
||||
with open(cached_features_file, "rb") as source:
|
||||
@@ -78,7 +82,7 @@ class TextDataset(Dataset):
|
||||
|
||||
logger.info("Creating features from dataset at %s", data_dir)
|
||||
|
||||
datasets = ['cnn', 'dailymail']
|
||||
datasets = ["cnn", "dailymail"]
|
||||
for dataset in datasets:
|
||||
path_to_stories = os.path.join(data_dir, dataset, "stories")
|
||||
assert os.path.isdir(path_to_stories)
|
||||
@@ -99,7 +103,9 @@ class TextDataset(Dataset):
|
||||
story = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(story))
|
||||
summary = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(summary))
|
||||
story_seq, summary_seq = _fit_to_block_size(story, summary, block_size)
|
||||
example = tokenizer.add_special_token_sequence_pair(story_seq, summary_seq)
|
||||
example = tokenizer.add_special_token_sequence_pair(
|
||||
story_seq, summary_seq
|
||||
)
|
||||
self.examples.append(example)
|
||||
|
||||
logger.info("Saving features into cache file %s", cached_features_file)
|
||||
@@ -117,7 +123,9 @@ def process_story(raw_story):
|
||||
""" Process the text contained in a story file.
|
||||
Returns the story and the summary
|
||||
"""
|
||||
file_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
file_lines = list(
|
||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||
)
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
file_lines = [_add_missing_period(line) for line in file_lines]
|
||||
@@ -145,7 +153,7 @@ def process_story(raw_story):
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', u'\u2019', u'\u2019', ")"]
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
@@ -154,34 +162,35 @@ def _add_missing_period(line):
|
||||
|
||||
|
||||
def _fit_to_block_size(src_sequence, tgt_sequence, block_size):
|
||||
""" Concatenate the sequences and adapt their lengths to the block size.
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
|
||||
Following [1] we truncate the source and target + tokens sequences so they fit
|
||||
in the block size. If the concatenated sequence is longer than 512 we follow
|
||||
the 75%/25% rule in [1]: limit the source sequence's length to 384 and the
|
||||
target sequence's length to 128.
|
||||
If the concatenated sequence (source + target + 3 special tokens) would be
|
||||
longer than the block size we use the 75% / 25% rule followed in [1]. For a
|
||||
block size of 512 this means limiting the source sequence's length to 384
|
||||
and the target sequence's length to 128.
|
||||
|
||||
[1] Dong, Li, et al. "Unified Language Model Pre-training for Natural
|
||||
Language Understanding and Generation." arXiv preprint arXiv:1905.03197 (2019).
|
||||
"""
|
||||
SRC_MAX_LENGTH = int(0.75 * block_size) - 2 # CLS and EOS token
|
||||
TGT_MAX_LENGTH = block_size - SRC_MAX_LENGTH - 1 # EOS token
|
||||
TGT_MAX_LENGTH = block_size - (SRC_MAX_LENGTH + 2) - 1 # EOS token
|
||||
|
||||
# we dump the examples that are too small to fit in the block size for the
|
||||
# We dump the examples that are too small to fit in the block size for the
|
||||
# sake of simplicity. You can modify this by adding model-specific padding.
|
||||
if len(src_sequence) + len(src_sequence) + 3 < block_size:
|
||||
if len(src_sequence) + len(tgt_sequence) + 3 < block_size:
|
||||
return None
|
||||
|
||||
# the source sequence has `[SEP_i]` special tokens with i \in [0,9]. We keep them for now.
|
||||
if len(src_sequence) > SRC_MAX_LENGTH:
|
||||
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
||||
src_sequence = src_sequence[:SRC_MAX_LENGTH]
|
||||
tgt_sequence = tgt_sequence[:TGT_MAX_LENGTH]
|
||||
else:
|
||||
src_sequence = src_sequence[block_size - len(tgt_sequence) - 3]
|
||||
remain_size = block_size - len(tgt_sequence) - 3
|
||||
src_sequence = src_sequence[:remain_size]
|
||||
else:
|
||||
if len(tgt_sequence) > TGT_MAX_LENGTH:
|
||||
tgt_sequence = tgt_sequence[block_size - len(src_sequence) - 3]
|
||||
remain_size = block_size - len(src_sequence) - 3
|
||||
tgt_sequence = tgt_sequence[:remain_size]
|
||||
|
||||
return src_sequence, tgt_sequence
|
||||
|
||||
@@ -200,44 +209,50 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Required parameters
|
||||
parser.add_argument("--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input training data file (a text file).")
|
||||
parser.add_argument("--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The input training data file (a text file).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Optional parameters
|
||||
parser.add_argument("--model_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint for weights initialization.")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default="bert-base-cased",
|
||||
type=str,
|
||||
help="The model checkpoint for weights initialization.",
|
||||
)
|
||||
parser.add_argument("--seed", default=42, type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set up training device
|
||||
device = torch.device("cpu")
|
||||
# device = torch.device("cpu")
|
||||
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# Load pretrained model and tokenizer
|
||||
config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
|
||||
config = config_class.from_pretrained(args.model_name_or_path)
|
||||
tokenizer_class = BertTokenizer
|
||||
# config = config_class.from_pretrained(args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
|
||||
model = model_class.from_pretrained(args.model_name_or_path, config=config)
|
||||
model.to(device)
|
||||
# model = model_class.from_pretrained(args.model_name_or_path, config=config)
|
||||
# model.to(device)
|
||||
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
train_dataset = load_and_cache_examples(args, tokenizer)
|
||||
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
_ = load_and_cache_examples(args, tokenizer)
|
||||
# global_step, tr_loss = train(args, train_dataset, model, tokenizer)
|
||||
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -14,50 +14,50 @@
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from .run_seq2seq_finetuning import process_story, _fit_to_block_size
|
||||
from run_seq2seq_finetuning import _fit_to_block_size
|
||||
|
||||
|
||||
class DataLoaderTest(unittest.TestCase):
|
||||
def __init__(self, block_size=10):
|
||||
self.block_size = block_size
|
||||
def setUp(self):
|
||||
self.block_size = 10
|
||||
|
||||
def source_and_target_too_small(self):
|
||||
def test_source_and_target_too_small(self):
|
||||
""" When the sum of the lengths of the source and target sequences is
|
||||
smaller than the block size (minus the number of special tokens), skip the example. """
|
||||
src_seq = [1, 2, 3, 4]
|
||||
tgt_seq = [5, 6]
|
||||
self.assertEqual(_fit_to_block_size(src_seq, tgt_seq, self.block_size), None)
|
||||
|
||||
def source_and_target_fit_exactly(self):
|
||||
def test_source_and_target_fit_exactly(self):
|
||||
""" When the sum of the lengths of the source and target sequences is
|
||||
equal to the block size (minus the number of special tokens), return the
|
||||
sequences unchanged. """
|
||||
src_seq = [1, 2, 3, 4]
|
||||
tgt_seq = [5, 6, 7]
|
||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||
self.assertListEqual(src_seq == fitted_src)
|
||||
self.assertListEqual(tgt_seq == fitted_tgt)
|
||||
self.assertListEqual(src_seq, fitted_src)
|
||||
self.assertListEqual(tgt_seq, fitted_tgt)
|
||||
|
||||
def source_too_big_target_ok(self):
|
||||
def test_source_too_big_target_ok(self):
|
||||
src_seq = [1, 2, 3, 4, 5, 6]
|
||||
tgt_seq = [1, 2]
|
||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
|
||||
self.assertListEqual(tgt_seq == fitted_tgt)
|
||||
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
|
||||
self.assertListEqual(fitted_tgt, fitted_tgt)
|
||||
|
||||
def target_too_big_source_ok(self):
|
||||
def test_target_too_big_source_ok(self):
|
||||
src_seq = [1, 2, 3, 4]
|
||||
tgt_seq = [1, 2, 3, 4]
|
||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||
self.assertListEqual(src_seq == src_seq)
|
||||
self.assertListEqual(tgt_seq == [1, 2, 3])
|
||||
self.assertListEqual(fitted_src, src_seq)
|
||||
self.assertListEqual(fitted_tgt, [1, 2, 3])
|
||||
|
||||
def source_and_target_too_big(self):
|
||||
def test_source_and_target_too_big(self):
|
||||
src_seq = [1, 2, 3, 4, 5, 6, 7]
|
||||
tgt_seq = [1, 2, 3, 4, 5, 6, 7]
|
||||
fitted_src, fitted_tgt = _fit_to_block_size(src_seq, tgt_seq, self.block_size)
|
||||
self.assertListEqual(src_seq == [1, 2, 3, 4, 5])
|
||||
self.assertListEqual(tgt_seq == [1, 2])
|
||||
self.assertListEqual(fitted_src, [1, 2, 3, 4, 5])
|
||||
self.assertListEqual(fitted_tgt, [1, 2])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user