here's one big commit
This commit is contained in:
@@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference
|
||||
|
||||
## Seq2seq model fine-tuning
|
||||
|
||||
Based on the script [`run_seq2seq_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_seq2seq_finetuning.py).
|
||||
Based on the script
|
||||
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
|
||||
|
||||
Before running this script you should download **both** CNN and Daily Mail
|
||||
datasets from [Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the
|
||||
@@ -412,7 +413,7 @@ archive.
|
||||
```bash
|
||||
export DATA_PATH=/path/to/dataset/
|
||||
|
||||
python run_seq2seq_finetuning.py \
|
||||
python run_summarization_finetuning.py \
|
||||
--output_dir=output \
|
||||
--model_type=bert2bert \
|
||||
--model_name_or_path=bert2bert \
|
||||
|
||||
@@ -1,361 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 The Microsoft Research team and The HuggingFace Inc. team.
|
||||
# Copyright (c) 2018 Microsoft and The HuggingFace Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning seq2seq models for sequence generation."""
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import logging
|
||||
import pickle
|
||||
import random
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
import torch
|
||||
from torch.utils.data import Dataset, RandomSampler
|
||||
|
||||
from transformers import AutoTokenizer, Model2Model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_seed(args):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        args: namespace exposing an integer ``seed`` attribute.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(args.seed)
|
||||
|
||||
|
||||
# ------------
|
||||
# Load dataset
|
||||
# ------------
|
||||
|
||||
|
||||
class TextDataset(Dataset):
    """Dataset of (story, summary) token-id pairs for seq2seq training.

    CNN / Daily Mail:

    The raw datasets are downloaded from [1]. Each story lives in its own
    file; the summary is appended at the end of the story as sentences
    introduced by the special `@highlight` line. To process the data, untar
    both datasets in the same folder and pass the path to that folder as the
    "data_dir" argument. The formatting code was inspired by [2].

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
        """Load the cached features, or build them from the raw stories.

        Args:
            tokenizer: object whose ``encode(text)`` returns a list of ids.
            prefix: only used to name the cache file.
            data_dir: folder containing `cnn/stories` and `dailymail/stories`.
            block_size: length every encoded sequence is cut or padded to.
        """
        assert os.path.isdir(data_dir)

        cache_path = os.path.join(
            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
        )
        if os.path.exists(cache_path):
            logger.info("Loading features from cached file %s", cache_path)
            with open(cache_path, "rb") as cache_file:
                self.examples = pickle.load(cache_file)
            return

        logger.info("Creating features from dataset at %s", data_dir)
        self.examples = []
        for corpus in ("cnn", "dailymail"):
            stories_dir = os.path.join(data_dir, corpus, "stories")
            assert os.path.isdir(stories_dir)

            for filename in os.listdir(stories_dir):
                story_path = os.path.join(stories_dir, filename)
                if not os.path.isfile(story_path):
                    continue

                with open(story_path, encoding="utf-8") as story_file:
                    try:
                        story, summary = process_story(story_file.read())
                    except IndexError:
                        # ill-formed story (empty, or without highlights)
                        continue

                story_ids = _fit_to_block_size(tokenizer.encode(story), block_size)
                summary_ids = _fit_to_block_size(
                    tokenizer.encode(summary), block_size
                )
                self.examples.append((story_ids, summary_ids))

        logger.info("Saving features into cache file %s", cache_path)
        with open(cache_path, "wb") as cache_file:
            pickle.dump(self.examples, cache_file, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Return the number of stored (story, summary) pairs."""
        return len(self.examples)

    def __getitem__(self, items):
        """Return the selected example(s) converted to a single tensor."""
        selected = self.examples[items]
        return torch.tensor(selected)
|
||||
|
||||
|
||||
def process_story(raw_story):
    """Split the content of a story file into its story and its summary.

    Args:
        raw_story (str): content of the story file.

    Returns:
        A ``(story, summary)`` pair of strings. The story is made of the
        lines found before the first `@highlight` marker, the summary of the
        non-marker lines found after it; both are joined with single spaces.

    Raises:
        IndexError: if the file is empty or contains no `@highlight` line.
    """
    stripped = (line.strip() for line in raw_story.split("\n"))
    # some lines miss their final period for an unknown reason: add it back
    remaining = deque(_add_missing_period(line) for line in stripped if len(line) != 0)

    # everything before the first marker belongs to the article
    story_lines = []
    while True:
        # popleft() raises IndexError once the lines are exhausted, which is
        # how a missing "@highlight" marker is reported to the caller
        current = remaining.popleft()
        if current.startswith("@highlight"):
            break
        story_lines.append(current)

    # what is left are the highlights, separated by further markers
    highlights_lines = [
        line for line in remaining if not line.startswith("@highlight")
    ]

    return " ".join(story_lines), " ".join(highlights_lines)
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
def _fit_to_block_size(sequence, block_size):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter than the block size we pad it with -1 ids
|
||||
which correspond to padding tokens.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([0] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def mask_padding_tokens(sequence):
    """Return a copy of `sequence` where every 0 (padding) id becomes -1."""
    masked = []
    for token_id in sequence:
        if token_id != 0:
            masked.append(token_id)
        else:
            masked.append(-1)
    return masked
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer):
    """Build the `TextDataset` for the stories stored under `args.data_dir`."""
    return TextDataset(tokenizer, data_dir=args.data_dir)
|
||||
|
||||
|
||||
# ------------
|
||||
# Train
|
||||
# ------------
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
    """Fine-tune the pretrained model on the corpus.

    Fixes over the previous version: the batch size was stored under the
    misspelled attribute `train_bach_size`, `t_total` was never computed,
    `DataLoader`, `AdamW` and `WarmupLinearSchedule` were not imported, the
    batch was turned into a Python list before calling `.to()`, and the
    average loss was divided by zero when no update had been performed.

    Args:
        args: namespace of hyper-parameters. Reads `seed`, `weight_decay`,
            `learning_rate`, `adam_epsilon`, `warmup_steps`,
            `num_train_epochs`, `max_steps` and `max_grad_norm`. `device`,
            `gradient_accumulation_steps` and `local_rank` are optional and
            default to the CPU, 1 and -1 respectively.
        train_dataset: dataset whose items stack a source and a target
            sequence in one tensor (see `TextDataset.__getitem__`).
        model: called as `model(source, target, decoder_lm_labels=...)`; the
            first element of its output is used as the loss.
        tokenizer: unused; kept so that the call signature does not change.

    Returns:
        tuple: the number of optimization steps performed and the loss
        averaged over these steps (0.0 when no step was performed).
    """
    # These names are used below but are not imported at the top of the file.
    from torch.utils.data import DataLoader
    from transformers import AdamW, WarmupLinearSchedule

    device = getattr(args, "device", torch.device("cpu"))
    accumulation_steps = max(1, getattr(args, "gradient_accumulation_steps", 1))
    world_size = (
        torch.distributed.get_world_size()
        if getattr(args, "local_rank", -1) != -1
        else 1
    )

    # Prepare the data loading
    args.train_batch_size = 1
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
    )

    # Training schedule: `max_steps`, when set, overrides the number of epochs
    updates_per_epoch = len(train_dataloader) // accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // max(1, updates_per_epoch) + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare the optimizer and schedule (linear warmup and decay);
    # biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    decayed, not_decayed = [], []
    for name, parameter in model.named_parameters():
        if any(nd in name for nd in no_decay):
            not_decayed.append(parameter)
        else:
            decayed.append(parameter)
    optimizer_grouped_parameters = [
        {"params": decayed, "weight_decay": args.weight_decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]
    optimizer = AdamW(
        optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon
    )
    scheduler = WarmupLinearSchedule(
        optimizer, warmup_steps=args.warmup_steps, t_total=t_total
    )

    # Train
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * accumulation_steps * world_size,
    )
    logger.info(" Gradient Accumulation steps = %d", accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss = 0.0
    model.to(device)
    model.zero_grad()
    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
    set_seed(args)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
        for step, batch in enumerate(epoch_iterator):
            # Each dataset item stacks (source, target), so after collation
            # index 0 / 1 along dim 1 selects the sources / the targets.
            source = batch[:, 0].to(device)
            target = batch[:, 1].to(device)
            # padding ids (0) are replaced by -1 in the labels
            labels = target.clone()
            labels[labels == 0] = -1

            model.train()
            outputs = model(source, target, decoder_lm_labels=labels)
            loss = outputs[0]
            if accumulation_steps > 1:
                loss = loss / accumulation_steps
            loss.backward()

            tr_loss += loss.item()
            if (step + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss
|
||||
|
||||
|
||||
def main():
    """Parse the command line, load the model and tokenizer, and train.

    Compared with the previous version this defines everything `train`
    reads from `args` (`device`, `gradient_accumulation_steps`,
    `per_gpu_train_batch_size`, `local_rank`), moves the model to the
    selected device, logs the training result, and corrects two help
    strings (`--data_dir` is a directory; "Weight deay" typo).

    Raises:
        ValueError: if `--model_type` is not "bert".
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The directory that contains the untarred `cnn` and `dailymail` datasets.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        default=1,
        type=int,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--local_rank",
        default=-1,
        type=int,
        help="Rank of the process for distributed training; -1 disables it.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder and decoder's weights with.",
    )
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=1,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=1,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument(
        "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some."
    )
    args = parser.parse_args()

    if args.model_type != "bert":
        raise ValueError(
            "Only the BERT architecture is currently supported for seq2seq."
        )

    # Set up training device; `train` sends every batch to `args.device`
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = Model2Model.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    train_dataset = load_and_cache_examples(args, tokenizer)
    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


if __name__ == "__main__":
    main()
|
||||
620
examples/run_summarization_finetuning.py
Normal file
620
examples/run_summarization_finetuning.py
Normal file
@@ -0,0 +1,620 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning seq2seq models for sequence generation."""
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
import torch
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
def set_seed(args):
    """Make runs reproducible by seeding every random number generator in use.

    Args:
        args: namespace exposing an integer ``seed`` attribute, shared by the
            `random`, NumPy and PyTorch generators.
    """
    seed_value = args.seed
    torch.manual_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)
|
||||
|
||||
|
||||
# ------------
|
||||
# Load dataset
|
||||
# ------------
|
||||
|
||||
|
||||
class TextDataset(Dataset):
    """Abstracts the dataset used to train seq2seq models.

    CNN / Daily Mail:

    The raw datasets are downloaded from [1]. The stories are stored in
    different files; the summary appears at the end of the story as
    sentences that are prefixed by the special `@highlight` line. To process
    the data, untar both datasets in the same folder, and pass the path to
    this folder as the "data_dir" argument. The formatting code was inspired
    by [2].

    The examples are stored in a dict with two parallel lists, "source"
    (stories) and "target" (summaries).

    [1] https://cs.nyu.edu/~kcho/
    [2] https://github.com/abisee/cnn-dailymail/
    """

    def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512):
        """Load the cached features, or build them from the raw stories.

        Args:
            tokenizer: tokenizer handed to `_encode_for_summarization`.
            prefix: only used to name the cache file.
            data_dir: folder containing `cnn/stories` and `dailymail/stories`.
            block_size: length every encoded sequence is cut or padded to.
        """
        assert os.path.isdir(data_dir), "data_dir must be an existing directory"

        # Load the features that have already been computed, if any
        cached_features_file = os.path.join(
            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
        )
        if os.path.exists(cached_features_file):
            logger.info("Loading features from cached file %s", cached_features_file)
            # NOTE(review): pickle executes arbitrary code on load; only
            # point data_dir at folders whose content is trusted.
            with open(cached_features_file, "rb") as source:
                self.examples = pickle.load(source)
            return

        logger.info("Creating features from dataset at %s", data_dir)
        datasets = ["cnn", "dailymail"]

        self.examples = {"source": [], "target": []}
        for dataset in datasets:
            path_to_stories = os.path.join(data_dir, dataset, "stories")
            for story_filename in os.listdir(path_to_stories):
                path_to_story = os.path.join(path_to_stories, story_filename)
                if not os.path.isfile(path_to_story):
                    continue

                with open(path_to_story, encoding="utf-8") as source:
                    raw_story = source.read()

                story_lines, summary_lines = process_story(raw_story)
                if len(summary_lines) == 0 or len(story_lines) == 0:
                    # skip stories without content or without highlights
                    continue

                story_token_ids, summary_token_ids = _encode_for_summarization(
                    story_lines, summary_lines, tokenizer
                )
                story_seq = _fit_to_block_size(story_token_ids, block_size)
                summary_seq = _fit_to_block_size(summary_token_ids, block_size)

                self.examples["source"].append(story_seq)
                # was appended to the non-existent key "summary" (KeyError),
                # while __getitem__ reads the "target" list
                self.examples["target"].append(summary_seq)

        logger.info("Saving features into cache file %s", cached_features_file)
        with open(cached_features_file, "wb") as sink:
            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        """Return the number of (story, summary) pairs.

        `len(self.examples)` would return the number of keys of the dict
        (always 2), hence the length of the "source" list is used.
        """
        return len(self.examples["source"])

    def __getitem__(self, items):
        """Return the selected story and summary as a pair of tensors."""
        return (
            torch.tensor(self.examples["source"][items]),
            torch.tensor(self.examples["target"][items]),
        )
|
||||
|
||||
|
||||
def process_story(raw_story):
    """Extract the story lines and the summary lines from a story file.

    Args:
        raw_story (str): content of the story file.

    Returns:
        A ``(story_lines, summary_lines)`` pair of lists of strings. The
        story lines are the non-empty lines found before the first
        `@highlight` marker; the summary lines are the non-marker lines found
        after it. When the file contains no marker, every line is a story
        line and the summary is an empty list. No exception is raised for
        empty or highlight-less files.
    """
    stripped_lines = (line.strip() for line in raw_story.split("\n"))
    # for some unknown reason some lines miss a period, add it
    nonempty_lines = [
        _add_missing_period(line) for line in stripped_lines if len(line) != 0
    ]

    # locate the first marker: the article sits before it, the summary after
    for position, line in enumerate(nonempty_lines):
        if line.startswith("@highlight"):
            break
    else:
        # "@highlight" is absent from the file: everything is story
        return nonempty_lines, []

    story_lines = nonempty_lines[:position]
    summary_lines = [
        line
        for line in nonempty_lines[position + 1:]
        if not line.startswith("@highlight")
    ]
    return story_lines, summary_lines
|
||||
|
||||
|
||||
def _encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
""" Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in story_lines
|
||||
]
|
||||
summary_lines_token_ids = [
|
||||
tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line))
|
||||
for line in summary_lines
|
||||
]
|
||||
|
||||
story_token_ids = [
|
||||
token for sentence in story_lines_token_ids for token in sentence
|
||||
]
|
||||
summary_token_ids = [
|
||||
token for sentence in summary_lines_token_ids for token in sentence
|
||||
]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
def _add_missing_period(line):
|
||||
END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"]
|
||||
if line.startswith("@highlight"):
|
||||
return line
|
||||
if line[-1] in END_TOKENS:
|
||||
return line
|
||||
return line + "."
|
||||
|
||||
|
||||
def _fit_to_block_size(sequence, block_size):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter than the block size we pad it with -1 ids
|
||||
which correspond to padding tokens.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
return sequence[:block_size]
|
||||
else:
|
||||
sequence.extend([0] * (block_size - len(sequence)))
|
||||
return sequence
|
||||
|
||||
|
||||
def mask_padding_tokens(sequence):
    """Return a copy of the tensor `sequence` in which the padding tokens,
    encoded as 0, are replaced by the value -1. The input is left untouched.
    """
    is_padding = sequence == 0
    return sequence.masked_fill(is_padding, -1)
|
||||
|
||||
|
||||
def load_and_cache_examples(args, tokenizer, evaluate=False):
    """Build the `TextDataset` for the stories stored under `args.data_dir`.

    Args:
        args: namespace exposing `data_dir`.
        tokenizer: tokenizer handed to the dataset.
        evaluate (bool): accepted because `evaluate()` calls this function
            with `evaluate=True`, which used to raise a TypeError. It only
            selects the prefix ("eval" / "train") passed to the dataset.
            NOTE(review): `TextDataset` uses the prefix solely to name its
            cache file and reads the same stories either way, so this does
            not yet yield a held-out split.

    Returns:
        TextDataset: the (possibly cached) dataset.
    """
    prefix = "eval" if evaluate else "train"
    return TextDataset(tokenizer, prefix=prefix, data_dir=args.data_dir)
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
    """Segment embeddings as described in [1]

    The values {0,1} were found in the repository [2].

    The segment id of a token is the parity of the number of separator
    tokens met so far in its own sequence, the current token included. The
    count restarts for every sequence: it previously carried over from one
    sequence of the batch to the next, so the ids of a sequence depended on
    the sequences placed before it.

    Attributes:
        batch: torch.Tensor, size [batch_size, block_size]
            Batch of input.
        separator_token_id: int
            The value of the token that separates the segments.

    Returns:
        torch.Tensor of integers with the same size as `batch`, holding
        values in {0, 1}.

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217)
    """
    batch = torch.as_tensor(batch)
    is_separator = (batch == separator_token_id).long()
    # running count of separators along each sequence, reduced to its parity
    return is_separator.cumsum(dim=-1) % 2
|
||||
|
||||
|
||||
# ----------
|
||||
# Optimizers
|
||||
# ----------
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
    """ Specific optimizer for BertSum.

    As described in [1], the authors fine-tune BertSum for abstractive
    summarization using two Adam Optimizers with different warm-up steps and
    learning rate. They also use a custom learning rate scheduler:

        lr = base_lr * min(step ** -0.5, step * warmup_steps ** -1.5)

    [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
        arXiv preprint arXiv:1908.08345 (2019).
    """

    def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9):
        """Create one Adam optimizer per stack.

        Args:
            model: object exposing `encoder` and `decoder` modules.
            lr (dict): base learning rate for the "encoder" and "decoder" keys.
            warmup_steps (dict): warm-up length for the same two keys.
            beta_1, beta_2, eps: Adam hyper-parameters shared by both stacks.
        """
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.lr = lr
        self.warmup_steps = warmup_steps

        self.optimizers = {
            "encoder": Adam(
                model.encoder.parameters(),
                lr=lr["encoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
            "decoder": Adam(
                model.decoder.parameters(),
                lr=lr["decoder"],
                betas=(beta_1, beta_2),
                eps=eps,
            ),
        }

        self._step = 0

    def _update_rate(self, stack):
        """Return the learning rate of `stack` for the current step.

        The rate grows linearly during the warm-up and decays as the inverse
        square root of the step afterwards. The warm-up exponent is -1.5 (it
        was -0.5), which makes the two branches meet at
        `step == warmup_steps`. Only meaningful once `step()` has been
        called, since `self._step` must be positive.
        """
        return self.lr[stack] * min(
            self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
        )

    def zero_grad(self):
        """Reset the gradients handled by both optimizers.

        The optimizers live in `self.optimizers`; the attributes
        `optimizer_decoder` / `optimizer_encoder` read by the previous
        implementation were never defined.
        """
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self):
        """Advance the schedule by one step and update both stacks."""
        self._step += 1
        for stack, optimizer in self.optimizers.items():
            new_rate = self._update_rate(stack)
            for param_group in optimizer.param_groups:
                param_group["lr"] = new_rate
            optimizer.step()
|
||||
|
||||
|
||||
# ------------
|
||||
# Train
|
||||
# ------------
|
||||
|
||||
|
||||
def train(args, model, tokenizer):
    """ Fine-tune the pretrained model on the corpus.

    Fixes over the previous version:
    - the attention masks are real 0/1 masks; the token ids with -1 padding
      produced by `mask_padding_tokens` were passed as masks, they are now
      only used as language-modeling labels;
    - the stray `print(loss)` is replaced by a debug log record;
    - the number of epochs derived from `max_steps` is
      `t_total // updates_per_epoch + 1`; the `+ 1` used to sit inside the
      divisor, which could give 0 epochs;
    - training stops after exactly `max_steps` updates (it ran one extra);
    - the average loss is 0.0 instead of a division by zero when no update
      was performed.

    Args:
        args: namespace of hyper-parameters. Reads `per_gpu_train_batch_size`,
            `n_gpu`, `gradient_accumulation_steps`, `max_steps`,
            `num_train_epochs`, `max_grad_norm`, `device`, plus what
            `set_seed` and `load_and_cache_examples` read.
        model: encoder-decoder model exposing `encoder` and `decoder`; the
            first element of its output is used as the loss.
        tokenizer: tokenizer; provides `cls_token_id` and is handed to the
            dataset.

    Returns:
        tuple: the number of optimization steps performed and the loss
        averaged over these steps.
    """
    set_seed(args)

    # Load the data
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_dataset = load_and_cache_examples(args, tokenizer)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size
    )

    # Training schedule
    updates_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = t_total // max(1, updates_per_epoch) + 1
    else:
        t_total = updates_per_epoch * args.num_train_epochs

    # Prepare the optimizer
    lr = {"encoder": 0.002, "decoder": 0.2}
    warmup_steps = {"encoder": 20000, "decoder": 10000}
    optimizer = BertSumOptimizer(model, lr, warmup_steps)

    # Train
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(
        " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
    )
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps,
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)

    model.zero_grad()
    train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)

    global_step = 0
    tr_loss = 0.0
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
        for step, batch in enumerate(epoch_iterator):
            source, target = batch
            token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
            # 1 for real tokens, 0 for padding (padding is encoded as id 0)
            source_mask = (source != 0).long()
            target_mask = (target != 0).long()
            # padding positions are set to -1 in the labels
            labels_tgt = mask_padding_tokens(target)

            source = source.to(args.device)
            target = target.to(args.device)
            token_type_ids = token_type_ids.to(args.device)
            source_mask = source_mask.to(args.device)
            target_mask = target_mask.to(args.device)
            labels_tgt = labels_tgt.to(args.device)

            model.train()
            outputs = model(
                source,
                target,
                token_type_ids=token_type_ids,
                decoder_encoder_attention_mask=source_mask,
                decoder_attention_mask=target_mask,
                decoder_lm_labels=labels_tgt,
                decoder_initialize_randomly=True,
            )

            loss = outputs[0]
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            step_loss = loss.item()
            tr_loss += step_loss
            logger.debug("step %d - loss %f", step, step_loss)
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                model.zero_grad()
                global_step += 1

            if args.max_steps > 0 and global_step >= args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step >= args.max_steps:
            train_iterator.close()
            break

    average_loss = tr_loss / global_step if global_step > 0 else 0.0
    return global_step, average_loss
|
||||
|
||||
|
||||
# ------------
|
||||
# Evaluate
|
||||
# ------------
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, prefix=""):
    """Compute the perplexity of the model and write it to the output dir.

    Fixes over the previous version:
    - `Tensor.to` returns a new tensor; its result was discarded, so the
      batches never reached `args.device`;
    - `load_and_cache_examples` was called with an `evaluate=True` keyword
      that its signature does not accept;
    - `args.per_gpu_eval_batch_size` is not defined by the argument parser;
      the training batch size is used when it is absent;
    - the attention masks are real 0/1 masks instead of token ids, and the
      segment ids used during training are also passed here.

    NOTE(review): the dataset has no held-out split, so this evaluates on
    the stories found under `args.data_dir`, i.e. the training data.

    Args:
        args: namespace of parameters (`device`, `n_gpu`, `output_dir`,
            `data_dir`, `seed`, batch sizes).
        model: encoder-decoder model; the first element of its output is
            used as the loss.
        tokenizer: tokenizer; provides `cls_token_id` and is handed to the
            dataset.
        prefix (str): label inserted in the log messages.

    Returns:
        dict: ``{"perplexity": tensor}``.

    Raises:
        ValueError: if the dataset yields no batch.
    """
    set_seed(args)

    per_gpu_batch_size = getattr(
        args, "per_gpu_eval_batch_size", args.per_gpu_train_batch_size
    )
    args.eval_batch_size = per_gpu_batch_size * max(1, args.n_gpu)
    eval_dataset = load_and_cache_examples(args, tokenizer)
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
    )

    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info(" Num examples = %d", len(eval_dataset))
    logger.info(" Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        source, target = batch
        token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id)
        # 1 for real tokens, 0 for padding (padding is encoded as id 0)
        source_mask = (source != 0).long()
        target_mask = (target != 0).long()
        labels_tgt = mask_padding_tokens(target)

        source = source.to(args.device)
        target = target.to(args.device)
        token_type_ids = token_type_ids.to(args.device)
        source_mask = source_mask.to(args.device)
        target_mask = target_mask.to(args.device)
        labels_tgt = labels_tgt.to(args.device)

        with torch.no_grad():
            outputs = model(
                source,
                target,
                token_type_ids=token_type_ids,
                decoder_encoder_attention_mask=source_mask,
                decoder_attention_mask=target_mask,
                decoder_lm_labels=labels_tgt,
            )
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    if nb_eval_steps == 0:
        raise ValueError("The evaluation dataset is empty.")

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    # Save the evaluation's results
    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result
|
||||
|
||||
|
||||
def _str_to_bool(value):
    """Convert a command-line string into a boolean.

    `type=bool` turns every non-empty string, "False" included, into True.
    This converter understands the usual spellings instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got {!r}.".format(value)
    )


def main():
    """Parse the command line, then train and/or evaluate the model.

    The boolean options accept `--flag`, `--flag=True` and `--flag False`;
    they were declared with `type=bool`, for which any non-empty value
    (including "False") means True. `--per_gpu_eval_batch_size`, which
    `evaluate` reads, is now defined.

    Returns:
        The evaluation results (an empty dict while the evaluation loop has
        no checkpoint to evaluate).

    Raises:
        ValueError: if training is requested and the output directory
            already contains files, unless overwriting is allowed.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The directory that contains the untarred `cnn` and `dailymail` datasets.",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--do_evaluate",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Run model evaluation on out-of-sample data.",
    )
    parser.add_argument(
        "--do_train",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Run training.",
    )
    parser.add_argument(
        "--do_overwrite_output_dir",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether to overwrite the output dir.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder and decoder's weights with.",
    )
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--to_cpu",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether to force training on CPU.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=1,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.do_overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
                args.output_dir
            )
        )

    # Set up training device
    if args.to_cpu or not torch.cuda.is_available():
        args.device = torch.device("cpu")
        args.n_gpu = 0
    else:
        args.device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()

    # Load pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model = Model2Model.from_pretrained(args.model_name_or_path)

    # Setup logging
    # NOTE(review): the module-level `logging.basicConfig(stream=sys.stdout,
    # ...)` call already installs a root handler, and basicConfig does
    # nothing once the root logger has one, so this format is not applied.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        0,
        args.device,
        args.n_gpu,
        False,
        False,
    )

    logger.info("Training/evaluation parameters %s", args)

    # Train the model
    model.to(args.device)
    if args.do_train:
        global_step, tr_loss = train(args, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)

        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))

    # Evaluate the model
    results = {}
    if args.do_evaluate:
        # TODO: the list of checkpoints is still empty, so nothing is
        # evaluated yet and `results` stays an empty dict.
        checkpoints = []
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            encoder_checkpoint = os.path.join(checkpoint, "encoder")
            decoder_checkpoint = os.path.join(checkpoint, "decoder")
            model = PreTrainedSeq2seq.from_pretrained(
                encoder_checkpoint, decoder_checkpoint
            )
            model.to(args.device)
            results = "placeholder"

    return results


if __name__ == "__main__":
    main()
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
from run_seq2seq_finetuning import _fit_to_block_size, process_story
|
||||
from run_summarization_finetuning import _fit_to_block_size, process_story
|
||||
|
||||
|
||||
class DataLoaderTest(unittest.TestCase):
|
||||
@@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase):
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
with self.assertRaises(IndexError):
|
||||
process_story(raw_story)
|
||||
_, summary = process_story(raw_story)
|
||||
self.assertEqual(summary, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
""" An empty story should also raise and exception.
|
||||
"""
|
||||
raw_story = ""
|
||||
with self.assertRaises(IndexError):
|
||||
process_story(raw_story)
|
||||
story, summary = process_story(raw_story)
|
||||
self.assertEqual(story, [])
|
||||
self.assertEqual(summary, [])
|
||||
|
||||
def test_story_with_missing_period(self):
|
||||
raw_story = (
|
||||
@@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase):
|
||||
"seventy-five\n\nSpiritual revelations were conceded to England "
|
||||
"at that favoured period, as at this.\n@highlight\n\nIt was the best of times"
|
||||
)
|
||||
story, summary = process_story(raw_story)
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
|
||||
expected_story = (
|
||||
"It was the year of Our Lord one thousand seven hundred and "
|
||||
"seventy-five. Spiritual revelations were conceded to England at that "
|
||||
"favoured period, as at this."
|
||||
)
|
||||
self.assertEqual(expected_story, story)
|
||||
expected_story_lines = [
|
||||
"It was the year of Our Lord one thousand seven hundred and seventy-five.",
|
||||
"Spiritual revelations were conceded to England at that favoured period, as at this.",
|
||||
]
|
||||
self.assertEqual(expected_story_lines, story_lines)
|
||||
|
||||
expected_summary = "It was the best of times."
|
||||
self.assertEqual(expected_summary, summary)
|
||||
expected_summary_lines = ["It was the best of times."]
|
||||
self.assertEqual(expected_summary_lines, summary_lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -87,7 +87,7 @@ if is_torch_available():
|
||||
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
|
||||
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
from .modeling_seq2seq import Model2Model
|
||||
from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
|
||||
|
||||
# Optimization
|
||||
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
|
||||
|
||||
240
transformers/modeling_beam_search.py
Normal file
240
transformers/modeling_beam_search.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# coding=utf-8
|
||||
# Copyright (c) 2019 Yang Liu
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
"""
|
||||
A general wrapper around models with LM heads to generate sequences
|
||||
using beam search.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ModelWithBeamSearch(nn.Module):
    """ Wraps an encoder-decoder model to generate sequences with beam search.

    The wrapped `model` is expected to expose an `encoder` and a `decoder`
    attribute (as `PreTrainedSeq2seq` does); the decoder must return the
    next-token prediction scores as the first element of its output.
    """

    def __init__(
        self,
        model,
        beam_size,
        start_token_id,
        end_token_id,
        pad_token_id,
        min_length,
        max_length,
        alpha,
        block_trigram=True,
    ):
        """
        Attributes:
            model: the encoder-decoder model used to score the next tokens.
            beam_size: number of hypotheses kept alive per batch element.
            start_token_id: token id every generated sequence starts with.
            end_token_id: token id that marks the end of a sequence.
            pad_token_id: token id used for padding (stored, not used by the search).
            min_length: the end token is forbidden before this many steps.
            max_length: maximum number of generation steps.
            alpha: exponent of the length penalty applied to the scores.
            block_trigram: discard a beam once it repeats one of its tri-grams.
        """
        super(ModelWithBeamSearch, self).__init__()
        self.model = model
        self.beam_size = beam_size
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        self.pad_token_id = pad_token_id
        self.min_length = min_length
        self.max_length = max_length
        self.alpha = alpha
        self.block_trigram = block_trigram

    def forward(self, input_ids, **kwargs):
        """ Generates, for each sequence of the batch, the hypothesis with the
        best length-penalized score.

        Args:
            input_ids: tensor of token ids fed to the encoder; the first
                dimension is the batch dimension.
            **kwargs: forwarded to the encoder, except for the keys prefixed
                with `decoder_` which are forwarded (prefix removed) to the
                decoder.

        Returns:
            A dict with the keys "predictions" and "scores". Each value is a
            list with one entry per batch element; that entry is a list holding
            the best hypothesis (tensor of token ids, start token included) and
            its score respectively.
        """
        # Separate the encoder- and decoder- specific kwargs. A kwarg is
        # decoder-specific if the key starts with `decoder_`
        kwargs_encoder = {
            argument: value
            for argument, value in kwargs.items()
            if not argument.startswith("decoder_")
        }
        kwargs_decoder = {
            argument[len("decoder_"):]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }

        batch_size = input_ids.size(0)
        device = input_ids.device

        # Variables that keep track of the status of the search
        hypotheses = [[] for _ in range(batch_size)]
        batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
        beam_offset = torch.arange(
            0,
            batch_size * self.beam_size,
            step=self.beam_size,
            dtype=torch.long,
            device=device,
        )
        growing_beam = torch.full(
            (batch_size * self.beam_size, 1),
            self.start_token_id,
            dtype=torch.long,
            device=device,
        )
        # Only the first beam of each batch element is alive at the start,
        # otherwise the first step would select `beam_size` identical tokens.
        topk_log_probabilities = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1),
            dtype=torch.float,
            device=device,
        ).repeat(batch_size)

        # Forward pass on the encoder. The first output is the sequence of
        # hidden states; it is repeated once per beam.
        encoder_outputs = self.model.encoder(input_ids, **kwargs_encoder)
        kwargs_decoder["encoder_hidden_states"] = tile(
            encoder_outputs[0], self.beam_size, dim=0
        )
        if kwargs_decoder.get("encoder_attention_mask") is not None:
            kwargs_decoder["encoder_attention_mask"] = tile(
                kwargs_decoder["encoder_attention_mask"], self.beam_size, dim=0
            )
        # Decoder inputs whose first dimension follows the (batch * beam) rows
        per_beam_state = ("encoder_hidden_states", "encoder_attention_mask")

        results = {}
        results["predictions"] = [[] for _ in range(batch_size)]
        results["scores"] = [[] for _ in range(batch_size)]

        for step in range(self.max_length):
            # No decoder state is cached, so the whole beam is fed again and
            # only the scores of the last position are kept.
            outputs = self.model.decoder(growing_beam, **kwargs_decoder)
            log_probabilities = torch.nn.functional.log_softmax(
                outputs[0][:, -1, :], dim=-1
            )
            vocab_size = log_probabilities.size(-1)

            # The batch size changes as some beams finish so we define:
            _B = log_probabilities.size(0) // self.beam_size

            # Multiply each beam probability with the probability of the
            # next token (conditioned on the words in the beam).
            log_probabilities = log_probabilities + topk_log_probabilities.view(-1, 1)

            # if the beam has not attained the minimum required length we
            # make the end token arbitrarily unlikely.
            if step < self.min_length:
                log_probabilities[:, self.end_token_id] = -1e20

            # Remove repeating tri-grams: a beam whose last tri-gram already
            # occurred earlier in the same beam is made arbitrarily unlikely.
            if self.block_trigram and step + 1 > 3:
                for row, tokens in enumerate(growing_beam.tolist()):
                    trigrams = [
                        tuple(tokens[k:k + 3]) for k in range(len(tokens) - 2)
                    ]
                    if trigrams[-1] in trigrams[:-1]:
                        log_probabilities[row] = -1e20

            # Find the `beam_size` (previous_beam + token) combinations with
            # the highest score
            topk_log_probabilities, topk_ids = log_probabilities.reshape(
                _B, self.beam_size * vocab_size
            ).topk(self.beam_size, dim=1)

            # Apply the length penalty. The +1 accounts for the [EOS] token
            # that will be added if the beam ends.
            length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha
            topk_scores = topk_log_probabilities / length_penalty

            # Retrieve the corresponding respective beam and token id
            # topk_token_ids[i] will be added to topk_beam_ids[i]
            topk_beam_ids = topk_ids // vocab_size
            topk_token_ids = topk_ids.fmod(vocab_size)

            # Retrieve the row index of the surviving beams in the original
            # view of the log_probabilities tensor; shape (_B, beam_size)
            surviving_beams_rows = topk_beam_ids + beam_offset[:_B].view(-1, 1)

            # Append the last predictions
            growing_beam = torch.cat(
                [
                    growing_beam.index_select(0, surviving_beams_rows.view(-1)),
                    topk_token_ids.view(-1, 1),
                ],
                1,
            )

            # Check if any of the beam searches has ended during this
            # growth step. Also if top beam (most probable) has ended
            # for one element of the batch.
            is_finished = topk_token_ids.eq(self.end_token_id)
            if step + 1 == self.max_length:
                is_finished.fill_(1)
            is_top_beam_finished = is_finished[:, 0].eq(1)

            # Save the finished searches
            if is_finished.any():
                predictions = growing_beam.view(
                    -1, self.beam_size, growing_beam.size(1)
                )
                for i in range(is_finished.size(0)):
                    if is_top_beam_finished[i]:
                        is_finished[i].fill_(1)
                    finished_hyp = is_finished[i].nonzero().view(-1)

                    # Store finished hypotheses for this batch.
                    b = batch_offset[i].item()
                    for j in finished_hyp:
                        hypotheses[b].append((topk_scores[i, j], predictions[i, j, :]))

                    # If the batch reached the end, save the best hypotheses
                    # in terms of length-penalized score.
                    if is_top_beam_finished[i]:
                        best_score, best_prediction = max(
                            hypotheses[b], key=lambda hyp: float(hyp[0])
                        )
                        results["scores"][b].append(best_score)
                        results["predictions"][b].append(best_prediction)

                non_finished = is_top_beam_finished.eq(0).nonzero().view(-1)
                if len(non_finished) == 0:
                    break

                # A beam that already produced the end token has been stored
                # above; make it arbitrarily unlikely so it is not extended.
                topk_log_probabilities = topk_log_probabilities.masked_fill(
                    is_finished, -1e10
                )

                # Remove finished batches for the next step.
                topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished)
                batch_offset = batch_offset.index_select(0, non_finished)
                surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished)
                growing_beam = predictions.index_select(0, non_finished).view(
                    -1, growing_beam.size(-1)
                )

            # Re-order the state for the next pass
            selected_rows = surviving_beams_rows.view(-1)
            for key in per_beam_state:
                if kwargs_decoder.get(key) is not None:
                    kwargs_decoder[key] = kwargs_decoder[key].index_select(
                        0, selected_rows
                    )

        return results
|
||||
|
||||
|
||||
def tile(x, count, dim=0):
    """
    Tiles `x` along dimension `dim` `count` times.

    Each slice along `dim` is repeated `count` times in a row (the copies of
    a slice are adjacent), and the result is a new contiguous tensor.

    Args:
        x: tensor with at least one dimension.
        count: number of copies of each slice.
        dim: dimension along which the slices are repeated.

    Returns:
        A tensor with the same shape as `x` except that the size of
        dimension `dim` is multiplied by `count`.

    Example:
        >> ex = torch.tensor([[1,2],[3,4]])
        >> tile(ex, 2, 0)
        torch.Tensor([[1,2],[1,2],[3,4],[3,4]])
    """
    if dim < 0:
        dim += x.dim()

    out_size = list(x.size())
    out_size[dim] *= count

    # Insert a unit dimension right after `dim`, repeat along it and merge it
    # back into `dim`: the copies of each slice end up next to each other.
    # The target shape is explicit, so empty tensors are handled as well.
    repeats = [1] * (x.dim() + 1)
    repeats[dim + 1] = count
    return x.unsqueeze(dim + 1).repeat(*repeats).view(*out_size)
|
||||
@@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel):
|
||||
if attention_mask.dim() == 2:
|
||||
if self.config.is_decoder:
|
||||
batch_size, seq_length = input_ids.size()
|
||||
seq_ids = torch.arange(seq_length)
|
||||
seq_ids = torch.arange(seq_length, device=input_ids.device)
|
||||
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||
else:
|
||||
@@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel):
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# If a 2D encoder attention mask is provided for the cross-attention
|
||||
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||
encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
self.bert.embeddings.word_embeddings)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
|
||||
masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
|
||||
masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ):
|
||||
|
||||
outputs = self.bert(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
@@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
# 1. If a tensor that contains the indices of masked labels is provided,
|
||||
# the cross-entropy is the MLM cross-entropy that measures the likelihood
|
||||
# of predictions for masked words.
|
||||
# 2. If encoder hidden states are provided we are in a causal situation where we
|
||||
# 2. If `lm_label` is provided we are in a causal scenario where we
|
||||
# try to predict the next word for each input in the encoder.
|
||||
if masked_lm_labels is not None and lm_labels is not None:
|
||||
raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
|
||||
|
||||
if masked_lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1) # -1 index = padding token
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
@@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
if lm_labels is not None:
|
||||
# we are doing next-token prediction; shift prediction scores and input ids by one
|
||||
prediction_scores = prediction_scores[:, :-1, :]
|
||||
lm_labels = lm_labels[:, 1:, :]
|
||||
lm_labels = lm_labels[:, 1:]
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
||||
seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1))
|
||||
outputs = (seq2seq_loss,) + outputs
|
||||
|
||||
return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
|
||||
|
||||
@@ -17,13 +17,12 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .file_utils import add_start_docstrings
|
||||
from .modeling_auto import AutoModel, AutoModelWithLMHead
|
||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
self.decoder = decoder
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
encoder_pretrained_model_name_or_path=None,
|
||||
decoder_pretrained_model_name_or_path=None,
|
||||
*model_args,
|
||||
**kwargs
|
||||
):
|
||||
r""" Instantiates an encoder and a decoder from one or two base classes
|
||||
of the library from pre-trained model checkpoints.
|
||||
|
||||
@@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific if the key starts with `decoder_`
|
||||
kwargs_decoder = {}
|
||||
kwargs_encoder = kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
if key.startswith("decoder_"):
|
||||
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# Load and initialize the encoder and decoder
|
||||
# The distinction between encoder and decoder at the model level is made
|
||||
# by the value of the flag `is_decoder` that we need to set correctly.
|
||||
encoder = kwargs.pop("encoder_model", None)
|
||||
encoder = kwargs_encoder.pop("encoder_model", None)
|
||||
if encoder is None:
|
||||
kwargs_encoder["is_decoder"] = False
|
||||
encoder = AutoModel.from_pretrained(
|
||||
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
|
||||
)
|
||||
|
||||
decoder = kwargs.pop("decoder_model", None)
|
||||
decoder = kwargs_decoder.pop("model", None)
|
||||
if decoder is None:
|
||||
kwargs_decoder["is_decoder"] = True
|
||||
decoder = AutoModelWithLMHead.from_pretrained(
|
||||
@@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
""" Save a Seq2Seq model and its configuration file in a format
|
||||
such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
|
||||
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
|
||||
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
|
||||
|
||||
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
||||
""" The forward pass on a seq2eq depends what we are performing:
|
||||
|
||||
@@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module):
|
||||
"""
|
||||
# Separate the encoder- and decoder- specific kwargs. A kwarg is
|
||||
# decoder-specific if the key starts with `decoder_`
|
||||
kwargs_decoder = {}
|
||||
kwargs_encoder = kwargs
|
||||
for key in kwargs_encoder.keys():
|
||||
if key.startswith("decoder_"):
|
||||
kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key)
|
||||
kwargs_encoder = {
|
||||
argument: value
|
||||
for argument, value in kwargs.items()
|
||||
if not argument.startswith("decoder_")
|
||||
}
|
||||
kwargs_decoder = {
|
||||
argument[len("decoder_") :]: value
|
||||
for argument, value in kwargs.items()
|
||||
if argument.startswith("decoder_")
|
||||
}
|
||||
|
||||
# Encode if needed (training, first prediction pass)
|
||||
encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack*
|
||||
encoder_hidden_states = encoder_outputs[0][
|
||||
-1
|
||||
] # output of the encoder *stack*
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
# Decode
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :]
|
||||
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
|
||||
|
||||
return decoder_outputs + encoder_outputs
|
||||
@@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
|
||||
if (
|
||||
"bert" not in pretrained_model_name_or_path
|
||||
or "roberta" in pretrained_model_name_or_path
|
||||
or "distilbert" in pretrained_model_name_or_path
|
||||
):
|
||||
raise ValueError("Only the Bert model is currently supported.")
|
||||
|
||||
model = super(Model2Model, cls).from_pretrained(
|
||||
encoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
decoder_pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
**kwargs)
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# Some architectures require for the decoder to be initialized randomly
|
||||
# before fine-tuning.
|
||||
if kwargs.get("decoder_initialize_randomly", False):
|
||||
model.decoder.init_weights()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user