resolve PR comments

This commit is contained in:
Rémi Louf
2019-10-29 17:10:20 +01:00
parent 4c3ac4a7d8
commit dfce409691
7 changed files with 647 additions and 473 deletions

View File

@@ -16,10 +16,9 @@
""" Finetuning seq2seq models for sequence generation."""
import argparse
from collections import deque
import functools
import logging
import os
import pickle
import random
import sys
@@ -29,7 +28,22 @@ import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedSeq2seq,
Model2Model,
)
from utils_summarization import (
CNNDailyMailDataset,
encode_for_summarization,
fit_to_block_size,
build_lm_labels,
build_mask,
compute_token_type_ids,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
@@ -46,194 +60,41 @@ def set_seed(args):
# ------------
class TextDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
CNN/Daily News:
The CNN/Daily News raw datasets are downloaded from [1]. The stories are
stored in different files; the summary appears at the end of the story as
sentences that are prefixed by the special `@highlight` line. To process
the data, untar both datasets in the same folder, and pass the path to this
folder as the "data_dir argument. The formatting code was inspired by [2].
[1] https://cs.nyu.edu/~kcho/
[2] https://github.com/abisee/cnn-dailymail/
"""
def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
assert os.path.isdir(data_dir)
# Load the features that have already been computed, if any
cached_features_file = os.path.join(
data_dir, "cached_lm_{}_{}".format(block_size, prefix)
)
if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, "rb") as source:
self.examples = pickle.load(source)
return
logger.info("Creating features from dataset at %s", data_dir)
datasets = ["cnn", "dailymail"]
self.examples = {"source": [], "target": []}
for dataset in datasets:
path_to_stories = os.path.join(data_dir, dataset, "stories")
story_filenames_list = os.listdir(path_to_stories)
for story_filename in story_filenames_list:
path_to_story = os.path.join(path_to_stories, story_filename)
if not os.path.isfile(path_to_story):
continue
with open(path_to_story, encoding="utf-8") as source:
raw_story = source.read()
story_lines, summary_lines = process_story(raw_story)
if len(summary_lines) == 0 or len(story_lines) == 0:
continue
story_token_ids, summary_token_ids = _encode_for_summarization(
story_lines, summary_lines, tokenizer
)
story_seq = _fit_to_block_size(story_token_ids, block_size)
self.examples["source"].append(story_seq)
summary_seq = _fit_to_block_size(summary_token_ids, block_size)
self.examples["summary"].append(summary_seq)
logger.info("Saving features into cache file %s", cached_features_file)
with open(cached_features_file, "wb") as sink:
pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.examples)
def __getitem__(self, items):
return (
torch.tensor(self.examples["source"][items]),
torch.tensor(self.examples["target"][items]),
)
def process_story(raw_story):
""" Extract the story and summary from a story file.
Attributes:
raw_story (str): content of the story file as an utf-8 encoded string.
Raises:
IndexError: If the stoy is empty or contains no highlights.
"""
nonempty_lines = list(
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
)
# for some unknown reason some lines miss a period, add it
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
# gather article lines
story_lines = []
lines = deque(nonempty_lines)
while True:
try:
element = lines.popleft()
if element.startswith("@highlight"):
break
story_lines.append(element)
except IndexError:
# if "@highlight" is absent from the file we pop
# all elements until there is None.
return story_lines, []
# gather summary lines
summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines))
return story_lines, summary_lines
def _encode_for_summarization(story_lines, summary_lines, tokenizer):
""" Encode the story and summary lines, and join them
as specified in [1] by using `[SEP] [CLS]` tokens to separate
sentences.
"""
story_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in story_lines
]
summary_lines_token_ids = [
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
for line in summary_lines
]
story_token_ids = [
token for sentence in story_lines_token_ids for token in sentence
]
summary_token_ids = [
token for sentence in summary_lines_token_ids for token in sentence
]
return story_token_ids, summary_token_ids
def _add_missing_period(line):
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
if line.startswith("@highlight"):
return line
if line[-1] in END_TOKENS:
return line
return line + "."
def _fit_to_block_size(sequence, block_size):
""" Adapt the source and target sequences' lengths to the block size.
If the sequence is shorter than the block size we pad it with -1 ids
which correspond to padding tokens.
"""
if len(sequence) > block_size:
return sequence[:block_size]
else:
sequence.extend([0] * (block_size - len(sequence)))
return sequence
def mask_padding_tokens(sequence):
""" Padding token, encoded as 0, are represented by the value -1 in the
masks """
padded = sequence.clone()
padded[padded == 0] = -1
return padded
def load_and_cache_examples(args, tokenizer):
dataset = TextDataset(tokenizer, data_dir=args.data_dir)
dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
return dataset
def compute_token_type_ids(batch, separator_token_id):
""" Segment embeddings as described in [1]
def collate(data, tokenizer, block_size):
""" List of tuple as an input. """
# remove the files with empty an story/summary, encode and fit to block
data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
data = [
encode_for_summarization(story, summary, tokenizer) for story, summary in data
]
data = [
(
fit_to_block_size(story, block_size, tokenizer.pad_token_id),
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
)
for story, summary in data
]
The values {0,1} were found in the repository [2].
stories = torch.tensor([story for story, summary in data])
summaries = torch.tensor([summary for story, summary in data])
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
Attributes:
batch: torch.Tensor, size [batch_size, block_size]
Batch of input.
separator_token_id: int
The value of the token that separates the segments.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
[2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
"""
batch_embeddings = []
sentence_num = 0
for sequence in batch:
embeddings = []
for s in sequence:
if s == separator_token_id:
sentence_num += 1
embeddings.append(sentence_num % 2)
batch_embeddings.append(embeddings)
return torch.tensor(batch_embeddings)
return (
stories,
summaries,
encoder_token_type_ids,
encoder_mask,
decoder_mask,
lm_labels,
)
# ----------
@@ -252,7 +113,7 @@ class BertSumOptimizer(object):
arXiv preprint arXiv:1908.08345 (2019).
"""
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
self.encoder = model.encoder
self.decoder = model.decoder
self.lr = lr
@@ -306,8 +167,12 @@ def train(args, model, tokenizer):
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_dataset = load_and_cache_examples(args, tokenizer)
train_sampler = RandomSampler(train_dataset)
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=model_collate_fn,
)
# Training schedule
@@ -351,26 +216,23 @@ def train(args, model, tokenizer):
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
for step, batch in enumerate(epoch_iterator):
source, target = batch
token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
token_type_ids = token_type_ids.to(args.device)
labels_src = labels_src.to(args.device)
labels_tgt = labels_tgt.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
model.train()
outputs = model(
source,
target,
token_type_ids=token_type_ids,
decoder_encoder_attention_mask=labels_src,
decoder_attention_mask=labels_tgt,
decoder_lm_labels=labels_tgt,
decoder_initialize_randomly=True,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
loss = outputs[0]
@@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""):
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target = batch
labels_src = mask_padding_tokens(source)
labels_tgt = mask_padding_tokens(target)
source.to(args.device)
target.to(args.device)
labels_src.to(args.device)
labels_tgt.to(args.device)
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
source,
target,
decoder_encoder_attention_mask=labels_src,
decoder_attention_mask=labels_tgt,
decoder_lm_labels=labels_tgt,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
@@ -525,7 +389,7 @@ def main():
)
parser.add_argument(
"--num_train_epochs",
default=1,
default=10,
type=int,
help="Total number of training epochs to perform.",
)
@@ -558,9 +422,13 @@ def main():
args.device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
# Load pretrained model and tokenizer
# Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
model = Model2Model.from_pretrained(args.model_name_or_path)
config = BertConfig.from_pretrained(args.model_name_or_path)
decoder_model = BertForMaskedLM(config)
model = Model2Model.from_pretrained(
args.model_name_or_path, decoder_model=decoder_model
)
# Setup logging
logging.basicConfig(