From 4c3ac4a7d83cdf37b796d783bb66a89bbd09ef9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 18 Oct 2019 12:29:30 +0200 Subject: [PATCH] here's one big commit --- examples/README.md | 5 +- examples/run_seq2seq_finetuning.py | 361 ---------- examples/run_summarization_finetuning.py | 620 ++++++++++++++++++ ...y => run_summarization_finetuning_test.py} | 28 +- transformers/__init__.py | 2 +- transformers/modeling_beam_search.py | 240 +++++++ transformers/modeling_bert.py | 20 +- transformers/modeling_seq2seq.py | 83 ++- 8 files changed, 951 insertions(+), 408 deletions(-) delete mode 100644 examples/run_seq2seq_finetuning.py create mode 100644 examples/run_summarization_finetuning.py rename examples/{run_seq2seq_finetuning_test.py => run_summarization_finetuning_test.py} (79%) create mode 100644 transformers/modeling_beam_search.py diff --git a/examples/README.md b/examples/README.md index e0fe1fc704..bec6d57171 100644 --- a/examples/README.md +++ b/examples/README.md @@ -393,7 +393,8 @@ This fine-tuned model is available as a checkpoint under the reference ## Seq2seq model fine-tuning -Based on the script [`run_seq2seq_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_seq2seq_finetuning.py). +Based on the script +[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py). Before running this script you should download **both** CNN and Daily Mail datasets from [Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the @@ -412,7 +413,7 @@ archive. ```bash export DATA_PATH=/path/to/dataset/ -python run_seq2seq_finetuning.py \ +python run_summarization_finetuning.py \ --output_dir=output \ --model_type=bert2bert \ --model_name_or_path=bert2bert \ diff --git a/examples/run_seq2seq_finetuning.py b/examples/run_seq2seq_finetuning.py deleted file mode 100644 index 61c4abfe6e..0000000000 --- a/examples/run_seq2seq_finetuning.py +++ /dev/null @@ -1,361 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Microsoft Reseach team and The HuggingFace Inc. team. -# Copyright (c) 2018 Microsoft and The HuggingFace Inc. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Finetuning seq2seq models for sequence generation.""" - -import argparse -from collections import deque -import logging -import pickle -import random -import os - -import numpy as np -from tqdm import tqdm, trange -import torch -from torch.utils.data import Dataset, RandomSampler - -from transformers import AutoTokenizer, Model2Model - -logger = logging.getLogger(__name__) - - -def set_seed(args): - random.seed(args.seed) - np.random.seed(args.seed) - torch.manual_seed(args.seed) - - -# ------------ -# Load dataset -# ------------ - - -class TextDataset(Dataset): - """ Abstracts the dataset used to train seq2seq models. - - CNN/Daily News: - - The CNN/Daily News raw datasets are downloaded from [1]. The stories are - stored in different files; the summary appears at the end of the story as - sentences that are prefixed by the special `@highlight` line. To process - the data, untar both datasets in the same folder, and pass the path to this - folder as the "data_dir argument. The formatting code was inspired by [2]. - - [1] https://cs.nyu.edu/~kcho/ - [2] https://github.com/abisee/cnn-dailymail/ - """ - - def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512): - assert os.path.isdir(data_dir) - - # Load features that have already been computed if present - cached_features_file = os.path.join( - data_dir, "cached_lm_{}_{}".format(block_size, prefix) - ) - if os.path.exists(cached_features_file): - logger.info("Loading features from cached file %s", cached_features_file) - with open(cached_features_file, "rb") as source: - self.examples = pickle.load(source) - return - - logger.info("Creating features from dataset at %s", data_dir) - self.examples = [] - datasets = ["cnn", "dailymail"] - for dataset in datasets: - path_to_stories = os.path.join(data_dir, dataset, "stories") - assert os.path.isdir(path_to_stories) - - story_filenames_list = os.listdir(path_to_stories) - for story_filename in story_filenames_list: - path_to_story = os.path.join(path_to_stories, story_filename) - if not os.path.isfile(path_to_story): - continue - - with open(path_to_story, encoding="utf-8") as source: - try: - raw_story = source.read() - story, summary = process_story(raw_story) - except IndexError: # skip ill-formed stories - continue - - story = tokenizer.encode(story) - story_seq = _fit_to_block_size(story, block_size) - - summary = tokenizer.encode(summary) - summary_seq = _fit_to_block_size(summary, block_size) - - self.examples.append((story_seq, summary_seq)) - - logger.info("Saving features into cache file %s", cached_features_file) - with open(cached_features_file, "wb") as sink: - pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL) - - def __len__(self): - return len(self.examples) - - def __getitem__(self, items): - return torch.tensor(self.examples[items]) - - -def process_story(raw_story): - """ Extract the story and summary from a story file. - - Attributes: - raw_story (str): content of the story file as an utf-8 encoded string. - - Raises: - IndexError: If the stoy is empty or contains no highlights. - """ - file_lines = list( - filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) - ) - - # for some unknown reason some lines miss a period, add it - file_lines = [_add_missing_period(line) for line in file_lines] - - # gather article lines - story_lines = [] - lines = deque(file_lines) - while True: - try: - element = lines.popleft() - if element.startswith("@highlight"): - break - story_lines.append(element) - except IndexError as ie: # if "@highlight" absent from file - raise ie - - # gather summary lines - highlights_lines = list(filter(lambda t: not t.startswith("@highlight"), lines)) - - # join the lines - story = " ".join(story_lines) - summary = " ".join(highlights_lines) - - return story, summary - - -def _add_missing_period(line): - END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"] - if line.startswith("@highlight"): - return line - if line[-1] in END_TOKENS: - return line - return line + "." - - -def _fit_to_block_size(sequence, block_size): - """ Adapt the source and target sequences' lengths to the block size. - If the sequence is shorter than the block size we pad it with -1 ids - which correspond to padding tokens. - """ - if len(sequence) > block_size: - return sequence[:block_size] - else: - sequence.extend([0] * (block_size - len(sequence))) - return sequence - - -def mask_padding_tokens(sequence): - """ Replace the padding token with -1 values """ - return [s if s != 0 else -1 for s in sequence] - - -def load_and_cache_examples(args, tokenizer): - dataset = TextDataset(tokenizer, data_dir=args.data_dir) - return dataset - - -# ------------ -# Train -# ------------ - - -def train(args, train_dataset, model, tokenizer): - """ Fine-tune the pretrained model on the corpus. """ - - # Prepare the data loading - args.train_bach_size = 1 - train_sampler = RandomSampler(train_dataset) - train_dataloader = DataLoader( - train_dataset, sampler=train_sampler, batch_size=args.train_bach_size - ) - - # Prepare the optimizer and schedule (linear warmup and decay) - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in model.named_parameters() - if not any(nd in n for nd in no_decay) - ], - "weight_decay": args.weight_decay, - }, - { - "params": [ - p - for n, p in model.named_parameters() - if any(nd in n for nd in no_decay) - ], - "weight_decay": 0.0, - }, - ] - optimizer = AdamW( - optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon - ) - scheduler = WarmupLinearSchedule( - optimizer, warmup_steps=args.warmup_steps, t_total=t_total - ) - - # Train - logger.info("***** Running training *****") - logger.info(" Num examples = %d", len(train_dataset)) - logger.info(" Num Epochs = %d", args.num_train_epochs) - logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) - logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", - args.train_batch_size - * args.gradient_accumulation_steps - * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), - ) - logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) - logger.info(" Total optimization steps = %d", t_total) - - global_step = 0 - tr_loss, logging_loss = 0.0, 0.0 - model.zero_grad() - train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True) - set_seed(args) - for _ in train_iterator: - epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) - for step, batch in enumerate(epoch_iterator): - source = ([s for s, _ in batch]).to(args.device) - target = ([t for _, t in batch]).to(args.device) - model.train() - outputs = model(source, target, decoder_lm_labels=mask_padding_tokens(target)) - loss = outputs[0] - loss.backward() - - tr_loss += loss.item() - if (step + 1) % args.gradient_accumulation_steps == 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) - optimizer.step() - scheduler.step() - model.zero_grad() - global_step += 1 - - if args.max_steps > 0 and global_step > args.max_steps: - epoch_iterator.close() - break - - if args.max_steps > 0 and global_step > args.max_steps: - train_iterator.close() - break - - return global_step, tr_loss / global_step - - -def main(): - parser = argparse.ArgumentParser() - - # Required parameters - parser.add_argument( - "--data_dir", - default=None, - type=str, - required=True, - help="The input training data file (a text file).", - ) - parser.add_argument( - "--output_dir", - default=None, - type=str, - required=True, - help="The output directory where the model predictions and checkpoints will be written.", - ) - - # Optional parameters - parser.add_argument( - "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer." - ) - parser.add_argument( - "--model_name_or_path", - default="bert-base-cased", - type=str, - help="The model checkpoint to initialize the encoder and decoder's weights with.", - ) - parser.add_argument( - "--model_type", - default="bert", - type=str, - help="The decoder architecture to be fine-tuned.", - ) - parser.add_argument( - "--learning_rate", - default=5e-5, - type=float, - help="The initial learning rate for Adam.", - ) - parser.add_argument( - "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." - ) - parser.add_argument( - "--max_steps", - default=-1, - type=int, - help="If > 0: set total number of training steps to perform. Override num_train_epochs.", - ) - parser.add_argument( - "--num_train_epochs", - default=1, - type=int, - help="Total number of training epochs to perform.", - ) - parser.add_argument("--seed", default=42, type=int) - parser.add_argument( - "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps." - ) - parser.add_argument( - "--weight_decay", default=0.0, type=float, help="Weight deay if we apply some." - ) - args = parser.parse_args() - - if args.model_type != "bert": - raise ValueError( - "Only the BERT architecture is currently supported for seq2seq." - ) - - # Set up training device - # device = torch.device("cpu") - - # Set seed - set_seed(args) - - # Load pretrained model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) - model = Model2Model.from_pretrained(args.model_name_or_path) - # model.to(device) - - logger.info("Training/evaluation parameters %s", args) - - # Training - train_dataset = load_and_cache_examples(args, tokenizer) - global_step, tr_loss = train(args, train_dataset, model, tokenizer) - # logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) - - -if __name__ == "__main__": - main() diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py new file mode 100644 index 0000000000..64bee82c5b --- /dev/null +++ b/examples/run_summarization_finetuning.py @@ -0,0 +1,620 @@ +# coding=utf-8 +# Copyright 2019 The HuggingFace Inc. team. +# Copyright (c) 2019 The HuggingFace Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning seq2seq models for sequence generation.""" + +import argparse +from collections import deque +import logging +import os +import pickle +import random +import sys + +import numpy as np +from tqdm import tqdm, trange +import torch +from torch.optim import Adam +from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler + +from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model + +logger = logging.getLogger(__name__) +logging.basicConfig(stream=sys.stdout, level=logging.INFO) + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + +# ------------ +# Load dataset +# ------------ + + +class TextDataset(Dataset): + """ Abstracts the dataset used to train seq2seq models. + + CNN/Daily News: + + The CNN/Daily News raw datasets are downloaded from [1]. The stories are + stored in different files; the summary appears at the end of the story as + sentences that are prefixed by the special `@highlight` line. To process + the data, untar both datasets in the same folder, and pass the path to this + folder as the "data_dir argument. The formatting code was inspired by [2]. + + [1] https://cs.nyu.edu/~kcho/ + [2] https://github.com/abisee/cnn-dailymail/ + """ + + def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512): + assert os.path.isdir(data_dir) + + # Load the features that have already been computed, if any + cached_features_file = os.path.join( + data_dir, "cached_lm_{}_{}".format(block_size, prefix) + ) + if os.path.exists(cached_features_file): + logger.info("Loading features from cached file %s", cached_features_file) + with open(cached_features_file, "rb") as source: + self.examples = pickle.load(source) + return + + logger.info("Creating features from dataset at %s", data_dir) + datasets = ["cnn", "dailymail"] + + self.examples = {"source": [], "target": []} + for dataset in datasets: + path_to_stories = os.path.join(data_dir, dataset, "stories") + story_filenames_list = os.listdir(path_to_stories) + for story_filename in story_filenames_list: + path_to_story = os.path.join(path_to_stories, story_filename) + if not os.path.isfile(path_to_story): + continue + + with open(path_to_story, encoding="utf-8") as source: + raw_story = source.read() + story_lines, summary_lines = process_story(raw_story) + if len(summary_lines) == 0 or len(story_lines) == 0: + continue + + story_token_ids, summary_token_ids = _encode_for_summarization( + story_lines, summary_lines, tokenizer + ) + story_seq = _fit_to_block_size(story_token_ids, block_size) + self.examples["source"].append(story_seq) + + summary_seq = _fit_to_block_size(summary_token_ids, block_size) + self.examples["summary"].append(summary_seq) + + logger.info("Saving features into cache file %s", cached_features_file) + with open(cached_features_file, "wb") as sink: + pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, items): + return ( + torch.tensor(self.examples["source"][items]), + torch.tensor(self.examples["target"][items]), + ) + + +def process_story(raw_story): + """ Extract the story and summary from a story file. + + Attributes: + raw_story (str): content of the story file as an utf-8 encoded string. + + Raises: + IndexError: If the stoy is empty or contains no highlights. + """ + nonempty_lines = list( + filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) + ) + + # for some unknown reason some lines miss a period, add it + nonempty_lines = [_add_missing_period(line) for line in nonempty_lines] + + # gather article lines + story_lines = [] + lines = deque(nonempty_lines) + while True: + try: + element = lines.popleft() + if element.startswith("@highlight"): + break + story_lines.append(element) + except IndexError: + # if "@highlight" is absent from the file we pop + # all elements until there is None. + return story_lines, [] + + # gather summary lines + summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines)) + + return story_lines, summary_lines + + +def _encode_for_summarization(story_lines, summary_lines, tokenizer): + """ Encode the story and summary lines, and join them + as specified in [1] by using `[SEP] [CLS]` tokens to separate + sentences. + """ + story_lines_token_ids = [ + tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) + for line in story_lines + ] + summary_lines_token_ids = [ + tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) + for line in summary_lines + ] + + story_token_ids = [ + token for sentence in story_lines_token_ids for token in sentence + ] + summary_token_ids = [ + token for sentence in summary_lines_token_ids for token in sentence + ] + + return story_token_ids, summary_token_ids + + +def _add_missing_period(line): + END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"] + if line.startswith("@highlight"): + return line + if line[-1] in END_TOKENS: + return line + return line + "." + + +def _fit_to_block_size(sequence, block_size): + """ Adapt the source and target sequences' lengths to the block size. + If the sequence is shorter than the block size we pad it with -1 ids + which correspond to padding tokens. + """ + if len(sequence) > block_size: + return sequence[:block_size] + else: + sequence.extend([0] * (block_size - len(sequence))) + return sequence + + +def mask_padding_tokens(sequence): + """ Padding token, encoded as 0, are represented by the value -1 in the + masks """ + padded = sequence.clone() + padded[padded == 0] = -1 + return padded + + +def load_and_cache_examples(args, tokenizer): + dataset = TextDataset(tokenizer, data_dir=args.data_dir) + return dataset + + +def compute_token_type_ids(batch, separator_token_id): + """ Segment embeddings as described in [1] + + The values {0,1} were found in the repository [2]. + + Attributes: + batch: torch.Tensor, size [batch_size, block_size] + Batch of input. + separator_token_id: int + The value of the token that separates the segments. + + [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." + arXiv preprint arXiv:1908.08345 (2019). + [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217) + """ + batch_embeddings = [] + sentence_num = 0 + for sequence in batch: + embeddings = [] + for s in sequence: + if s == separator_token_id: + sentence_num += 1 + embeddings.append(sentence_num % 2) + batch_embeddings.append(embeddings) + return torch.tensor(batch_embeddings) + + +# ---------- +# Optimizers +# ---------- + + +class BertSumOptimizer(object): + """ Specific optimizer for BertSum. + + As described in [1], the authors fine-tune BertSum for abstractive + summarization using two Adam Optimizers with different warm-up steps and + learning rate. They also use a custom learning rate scheduler. + + [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." + arXiv preprint arXiv:1908.08345 (2019). + """ + + def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9): + self.encoder = model.encoder + self.decoder = model.decoder + self.lr = lr + self.warmup_steps = warmup_steps + + self.optimizers = { + "encoder": Adam( + model.encoder.parameters(), + lr=lr["encoder"], + betas=(beta_1, beta_2), + eps=eps, + ), + "decoder": Adam( + model.decoder.parameters(), + lr=lr["decoder"], + betas=(beta_1, beta_2), + eps=eps, + ), + } + + self._step = 0 + + def _update_rate(self, stack): + return self.lr[stack] * min( + self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5) + ) + + def zero_grad(self): + self.optimizer_decoder.zero_grad() + self.optimizer_encoder.zero_grad() + + def step(self): + self._step += 1 + for stack, optimizer in self.optimizers.items(): + new_rate = self._update_rate(stack) + for param_group in optimizer.param_groups: + param_group["lr"] = new_rate + optimizer.step() + + +# ------------ +# Train +# ------------ + + +def train(args, model, tokenizer): + """ Fine-tune the pretrained model on the corpus. """ + set_seed(args) + + # Load the data + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_dataset = load_and_cache_examples(args, tokenizer) + train_sampler = RandomSampler(train_dataset) + train_dataloader = DataLoader( + train_dataset, sampler=train_sampler, batch_size=args.train_batch_size + ) + + # Training schedule + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = t_total // ( + len(train_dataloader) // args.gradient_accumulation_steps + 1 + ) + else: + t_total = ( + len(train_dataloader) + // args.gradient_accumulation_steps + * args.num_train_epochs + ) + + # Prepare the optimizer + lr = {"encoder": 0.002, "decoder": 0.2} + warmup_steps = {"encoder": 20000, "decoder": 10000} + optimizer = BertSumOptimizer(model, lr, warmup_steps) + + # Train + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info( + " Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size + ) + logger.info( + " Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps + # * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), + ) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + model.zero_grad() + train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True) + + global_step = 0 + tr_loss = 0.0 + for _ in train_iterator: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) + for step, batch in enumerate(epoch_iterator): + source, target = batch + token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id) + labels_src = mask_padding_tokens(source) + labels_tgt = mask_padding_tokens(target) + + source = source.to(args.device) + target = target.to(args.device) + token_type_ids = token_type_ids.to(args.device) + labels_src = labels_src.to(args.device) + labels_tgt = labels_tgt.to(args.device) + + model.train() + outputs = model( + source, + target, + token_type_ids=token_type_ids, + decoder_encoder_attention_mask=labels_src, + decoder_attention_mask=labels_tgt, + decoder_lm_labels=labels_tgt, + decoder_initialize_randomly=True, + ) + + loss = outputs[0] + print(loss) + if args.gradient_accumulation_steps > 1: + loss /= args.gradient_accumulation_steps + + loss.backward() + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + model.zero_grad() + global_step += 1 + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + return global_step, tr_loss / global_step + + +# ------------ +# Train +# ------------ + + +def evaluate(args, model, tokenizer, prefix=""): + set_seed(args) + + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True) + eval_sampler = SequentialSampler(eval_dataset) + eval_dataloader = DataLoader( + eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size + ) + + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(eval_dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + eval_loss = 0.0 + nb_eval_steps = 0 + model.eval() + + for batch in tqdm(eval_dataloader, desc="Evaluating"): + source, target = batch + labels_src = mask_padding_tokens(source) + labels_tgt = mask_padding_tokens(target) + source.to(args.device) + target.to(args.device) + labels_src.to(args.device) + labels_tgt.to(args.device) + + with torch.no_grad(): + outputs = model( + source, + target, + decoder_encoder_attention_mask=labels_src, + decoder_attention_mask=labels_tgt, + decoder_lm_labels=labels_tgt, + ) + lm_loss = outputs[0] + eval_loss += lm_loss.mean().item() + nb_eval_steps += 1 + + eval_loss = eval_loss / nb_eval_steps + perplexity = torch.exp(torch.tensor(eval_loss)) + + result = {"perplexity": perplexity} + + # Save the evaluation's results + output_eval_file = os.path.join(args.output_dir, "eval_results.txt") + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results {} *****".format(prefix)) + for key in sorted(result.keys()): + logger.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) + + return result + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--data_dir", + default=None, + type=str, + required=True, + help="The input training data file (a text file).", + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + # Optional parameters + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--do_evaluate", + type=bool, + default=False, + help="Run model evaluation on out-of-sample data.", + ) + parser.add_argument("--do_train", type=bool, default=False, help="Run training.") + parser.add_argument( + "--do_overwrite_output_dir", + type=bool, + default=False, + help="Whether to overwrite the output dir.", + ) + parser.add_argument( + "--model_name_or_path", + default="bert-base-cased", + type=str, + help="The model checkpoint to initialize the encoder and decoder's weights with.", + ) + parser.add_argument( + "--model_type", + default="bert", + type=str, + help="The decoder architecture to be fine-tuned.", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--to_cpu", default=False, type=bool, help="Whether to force training on CPU." + ) + parser.add_argument( + "--num_train_epochs", + default=1, + type=int, + help="Total number of training epochs to perform.", + ) + parser.add_argument( + "--per_gpu_train_batch_size", + default=4, + type=int, + help="Batch size per GPU/CPU for training.", + ) + parser.add_argument("--seed", default=42, type=int) + args = parser.parse_args() + + if ( + os.path.exists(args.output_dir) + and os.listdir(args.output_dir) + and args.do_train + and not args.do_overwrite_output_dir + ): + raise ValueError( + "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format( + args.output_dir + ) + ) + + # Set up training device + if args.to_cpu or not torch.cuda.is_available(): + args.device = torch.device("cpu") + args.n_gpu = 0 + else: + args.device = torch.device("cuda") + args.n_gpu = torch.cuda.device_count() + + # Load pretrained model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + model = Model2Model.from_pretrained(args.model_name_or_path) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + 0, + args.device, + args.n_gpu, + False, + False, + ) + + logger.info("Training/evaluation parameters %s", args) + + # Train the model + model.to(args.device) + if args.do_train: + global_step, tr_loss = train(args, model, tokenizer) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + logger.info("Saving model checkpoint to %s", args.output_dir) + + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + model_to_save = ( + model.module if hasattr(model, "module") else model + ) # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + torch.save(args, os.path.join(args.output_dir, "training_arguments.bin")) + + # Evaluate the model + results = {} + if args.do_evaluate: + checkpoints = [] + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + encoder_checkpoint = os.path.join(checkpoint, "encoder") + decoder_checkpoint = os.path.join(checkpoint, "decoder") + model = PreTrainedSeq2seq.from_pretrained( + encoder_checkpoint, decoder_checkpoint + ) + model.to(args.device) + results = "placeholder" + + return results + + +if __name__ == "__main__": + main() diff --git a/examples/run_seq2seq_finetuning_test.py b/examples/run_summarization_finetuning_test.py similarity index 79% rename from examples/run_seq2seq_finetuning_test.py rename to examples/run_summarization_finetuning_test.py index 77dc58666c..fd997ee0c2 100644 --- a/examples/run_seq2seq_finetuning_test.py +++ b/examples/run_summarization_finetuning_test.py @@ -14,7 +14,7 @@ # limitations under the License. import unittest -from run_seq2seq_finetuning import _fit_to_block_size, process_story +from run_summarization_finetuning import _fit_to_block_size, process_story class DataLoaderTest(unittest.TestCase): @@ -43,15 +43,16 @@ class DataLoaderTest(unittest.TestCase): raw_story = """It was the year of Our Lord one thousand seven hundred and seventy-five.\n\nSpiritual revelations were conceded to England at that favoured period, as at this.""" - with self.assertRaises(IndexError): - process_story(raw_story) + _, summary = process_story(raw_story) + self.assertEqual(summary, []) def test_process_empty_story(self): """ An empty story should also raise and exception. """ raw_story = "" - with self.assertRaises(IndexError): - process_story(raw_story) + story, summary = process_story(raw_story) + self.assertEqual(story, []) + self.assertEqual(summary, []) def test_story_with_missing_period(self): raw_story = ( @@ -59,17 +60,16 @@ class DataLoaderTest(unittest.TestCase): "seventy-five\n\nSpiritual revelations were conceded to England " "at that favoured period, as at this.\n@highlight\n\nIt was the best of times" ) - story, summary = process_story(raw_story) + story_lines, summary_lines = process_story(raw_story) - expected_story = ( - "It was the year of Our Lord one thousand seven hundred and " - "seventy-five. Spiritual revelations were conceded to England at that " - "favoured period, as at this." - ) - self.assertEqual(expected_story, story) + expected_story_lines = [ + "It was the year of Our Lord one thousand seven hundred and seventy-five.", + "Spiritual revelations were conceded to England at that favoured period, as at this.", + ] + self.assertEqual(expected_story_lines, story_lines) - expected_summary = "It was the best of times." - self.assertEqual(expected_summary, summary) + expected_summary_lines = ["It was the best of times."] + self.assertEqual(expected_summary_lines, summary_lines) if __name__ == "__main__": diff --git a/transformers/__init__.py b/transformers/__init__.py index ee8e812a23..2206a0302e 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -87,7 +87,7 @@ if is_torch_available(): from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) - from .modeling_seq2seq import Model2Model + from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model # Optimization from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, diff --git a/transformers/modeling_beam_search.py b/transformers/modeling_beam_search.py new file mode 100644 index 0000000000..3a27625f90 --- /dev/null +++ b/transformers/modeling_beam_search.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright (c) 2019 Yang Liu + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +A general wrapper around models with LM heads to generate sequences +using beam search. +""" +import torch +from torch import nn + + +class ModelWithBeamSearch(nn.Module): + def __init__( + self, + model, + beam_size, + start_token_id, + end_token_id, + pad_token_id, + min_length, + max_length, + alpha, + block_trigram=True, + ): + """ + Attributes: + mask_word_id: token id that corresponds to the mask + """ + super(ModelWithBeamSearch, self).__init__() + self.model = model + self.beam_size = beam_size + self.start_token_id = start_token_id + self.end_token_id = end_token_id + self.pad_token_id = pad_token_id + self.min_length = min_length + self.max_length = max_length + self.alpha = alpha + self.block_trigram = block_trigram + + def forward(self, input_ids, **kwargs): + # Separate the encoder- and decoder- specific kwargs. A kwarg is + # decoder-specific it the key starts with `decoder_` + kwargs_encoder = { + argument: value + for argument, value in kwargs.items() + if not argument.startswith("decoder_") + } + kwargs_decoder = { + argument[len("decoder_"):]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } + + batch_size, _ = input_ids.size(0) + + # Variables that keep track of the status of the search + hypotheses = [[] for _ in range(batch_size)] + batch_offset = torch.arange(batch_size, dtype=torch.long) + beam_offset = torch.arange( + 0, + batch_size * self.beam_size, + step=self.beam_size, + dtype=torch.long, + ) + growing_beam = torch.full( + (batch_size * self.beam_size, 1), + self.start_token_id, + dtype=torch.long, + ) + topk_log_probabilities = torch.tensor( + [0.0] + [float("-inf")] * (self.beam_size - 1), + dtype=torch.float, + ).repeat(batch_size) + + # Forward pass on the encoder + encoder_outputs = self.encoder(input_ids, kwargs_encoder) + kwargs_decoder["encoder_hidden_states"] = tile( + encoder_outputs, self.beam_size, dim=0 + ) + + results = {} + results["predictions"] = [[] for _ in batch_size] + results["scores"] = [[] for _ in batch_size] + + for step in range(self.max_length): + decoder_input = growing_beam[:, -1] + outputs = self.decoder(decoder_input, kwargs_decoder) + log_probabilities = torch.nn.functional.log_softmax(outputs[1]) + vocab_size = log_probabilities.size(-1) + + # The batch size changes as some beams finish so we define: + _B = log_probabilities.size(0) // self.beam_size + + # Multiply each beam probability with the probability of the + # next token (conditioned on the words in the beam). + log_probabilities += topk_log_probabilities.view(-1, 1) + + # if the beam has not attained the minimum required length we + # make the end token arbitrarily unlikely. + if step < self.min_length: + log_probabilities[self.end_token_id] = -1e20 + + # Remove repeating tri-grams + if(self.args.block_trigram): + if(step + 1 > 3): + for i in range(_B * self.beam_size): + tokens = [t for t in growing_beam[i]] + trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)] + last_trigram = tuple(trigrams[-1]) + if last_trigram in trigrams[:-1]: + log_probabilities[i] = -1e20 + + # Find the `beam_size` (previous_beam + token) combinations with + # the highest score + topk_log_probabilities, topk_ids = log_probabilities.topk( + log_probabilities.view(_B, self.beam_size * vocab_size), + self.beam_size, + dim=1 + ) + + # Apply the length penalty. The +1 accounts for the [EOS] token + # that will be added if the beam ends. + length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha + topk_scores = topk_log_probabilities / length_penalty + + # Retrieve the corresponding respective beam and token id + # topk_token_ids[i] will be added to topk_beam_ids[i] + topk_beam_ids = topk_ids.div(vocab_size) + topk_token_ids = topk_ids.fmod(vocab_size) + + # Retrieve the row index of the surviving beams in the original + # view of the log_probabilities tensor + surviving_beams_rows = ( + topk_beam_ids + beam_offset[:_B].view(-1, 1) + ).view(-1) + + # Append the last predictions + growing_beam = torch.cat( + [ + growing_beam.index_select(0, surviving_beams_rows), + topk_token_ids.view(-1, 1), + ], + 1, + ) + + # Check if any of the beam searches has ended during this + # growth step. Also if top beam (most probable) has ended + # for one element of the batch. + is_finished = topk_token_ids.eq(self.end_token_id) + if step + 1 == self.max_length: + is_finished.fill_(1) + is_top_beam_finished = is_finished[:, 0].eq(1) + + # Save the finished searches + if is_finished.any(): + predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1)) + for i in range(is_finished.size(0)): + if is_top_beam_finished[i]: + is_finished[i].fill_(1) + finished_hyp = is_finished[i].nonzero().view(-1) + + # Store finished hypotheses for this batch. + b = batch_offset[i] + for j in finished_hyp: + hypotheses[b].append((topk_scores[i, j], predictions[i, j, :])) + + # If the batch reached the end, save the best hypotheses + # in terms of length-penalized score. + if is_top_beam_finished[i]: + best_hyp = sorted( + hypotheses[b], key=lambda x: x[0], reverse=True + ) + best_score, best_prediction = best_hyp[0] + results["scores"][b].append(best_score) + results["predictions"][b].append(best_prediction) + + non_finished = is_top_beam_finished.eq(0).nonzero().view(-1) + if len(non_finished) == 0: + break + + # Remove finished batches for the next step. + topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished) + batch_offset = batch_offset.index_select(0, non_finished) + growing_beam = predictions.index_select(0, non_finished).view( + -1, growing_beam.size(-1) + ) + + # Re-order the state for the next pass + surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished) + kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ + "encoder_hidden_states" + ].index_select(0, surviving_beams_rows) + + return results + + +def tile(x, count, dim=0): + """ + Tiles `x` along dimension `dim` `count` times. + + Example: + >> ex = torch.tensor([1,2],[3,4]) + >> tile(ex, 2, 0) + torch.Tensor([[1,2],[1,2],[3,4],[3,4]]) + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = ( + x.view(batch, -1) + .transpose(0, 1) + .repeat(count, 1) + .transpose(0, 1) + .contiguous() + .view(*out_size) + ) + if dim != 0: + x = x.permute(perm).contiguous() + return x diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index d10f32c1fa..93f3c7e1f1 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -646,7 +646,7 @@ class BertModel(BertPreTrainedModel): if attention_mask.dim() == 2: if self.config.is_decoder: batch_size, seq_length = input_ids.size() - seq_ids = torch.arange(seq_length) + seq_ids = torch.arange(seq_length, device=input_ids.device) causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] else: @@ -660,6 +660,13 @@ class BertModel(BertPreTrainedModel): extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + # If a 2D encoder attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask[:, None, None, :] + encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0 + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N @@ -819,7 +826,7 @@ class BertForMaskedLM(BertPreTrainedModel): self.bert.embeddings.word_embeddings) def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, - masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None): + masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None, lm_labels=None, ): outputs = self.bert(input_ids, attention_mask=attention_mask, @@ -838,11 +845,8 @@ class BertForMaskedLM(BertPreTrainedModel): # 1. If a tensor that contains the indices of masked labels is provided, # the cross-entropy is the MLM cross-entropy that measures the likelihood # of predictions for masked words. - # 2. If encoder hidden states are provided we are in a causal situation where we + # 2. If `lm_label` is provided we are in a causal scenario where we # try to predict the next word for each input in the encoder. - if masked_lm_labels is not None and lm_labels is not None: - raise AttributeError("Masked LM training with an encoder-decoder is not supported.") - if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) # -1 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) @@ -851,9 +855,9 @@ class BertForMaskedLM(BertPreTrainedModel): if lm_labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one prediction_scores = prediction_scores[:, :-1, :] - lm_labels = lm_labels[:, 1:, :] + lm_labels = lm_labels[:, 1:] loss_fct = CrossEntropyLoss(ignore_index=-1) - seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) + seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1)) outputs = (seq2seq_loss,) + outputs return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions) diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 108fdaa853..2767dd2cd1 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -17,13 +17,12 @@ from __future__ import absolute_import, division, print_function, unicode_literals import logging +import os import torch from torch import nn -from .file_utils import add_start_docstrings from .modeling_auto import AutoModel, AutoModelWithLMHead -from .modeling_utils import PreTrainedModel, SequenceSummary logger = logging.getLogger(__name__) @@ -43,7 +42,13 @@ class PreTrainedSeq2seq(nn.Module): self.decoder = decoder @classmethod - def from_pretrained(cls, encoder_pretrained_model_name_or_path=None, decoder_pretrained_model_name_or_path=None, *model_args, **kwargs): + def from_pretrained( + cls, + encoder_pretrained_model_name_or_path=None, + decoder_pretrained_model_name_or_path=None, + *model_args, + **kwargs + ): r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints. @@ -108,23 +113,28 @@ class PreTrainedSeq2seq(nn.Module): # Separate the encoder- and decoder- specific kwargs. A kwarg is # decoder-specific it the key starts with `decoder_` - kwargs_decoder = {} - kwargs_encoder = kwargs - for key in kwargs_encoder.keys(): - if key.startswith("decoder_"): - kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key) + kwargs_encoder = { + argument: value + for argument, value in kwargs.items() + if not argument.startswith("decoder_") + } + kwargs_decoder = { + argument[len("decoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } # Load and initialize the encoder and decoder - # The distinction between encoder and decoder at the model level is made - # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs.pop("encoder_model", None) + # The distinction between encoder and decoder at the model level is made + # by the value of the flag `is_decoder` that we need to set correctly. + encoder = kwargs_encoder.pop("encoder_model", None) if encoder is None: kwargs_encoder["is_decoder"] = False encoder = AutoModel.from_pretrained( encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder ) - decoder = kwargs.pop("decoder_model", None) + decoder = kwargs_decoder.pop("model", None) if decoder is None: kwargs_decoder["is_decoder"] = True decoder = AutoModelWithLMHead.from_pretrained( @@ -135,6 +145,12 @@ class PreTrainedSeq2seq(nn.Module): return model + def save_pretrained(self, save_directory): + """ Save a Seq2Seq model and its configuration file in a format + such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """ + self.encoder.save_pretrained(os.path.join(save_directory, "encoder")) + self.decoder.save_pretrained(os.path.join(save_directory, "decoder")) + def forward(self, encoder_input_ids, decoder_input_ids, **kwargs): """ The forward pass on a seq2eq depends what we are performing: @@ -155,22 +171,29 @@ class PreTrainedSeq2seq(nn.Module): """ # Separate the encoder- and decoder- specific kwargs. A kwarg is # decoder-specific it the key starts with `decoder_` - kwargs_decoder = {} - kwargs_encoder = kwargs - for key in kwargs_encoder.keys(): - if key.startswith("decoder_"): - kwargs_decoder[key.replace("decoder_", "")] = kwargs_encoder.pop(key) + kwargs_encoder = { + argument: value + for argument, value in kwargs.items() + if not argument.startswith("decoder_") + } + kwargs_decoder = { + argument[len("decoder_") :]: value + for argument, value in kwargs.items() + if argument.startswith("decoder_") + } # Encode if needed (training, first prediction pass) encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None) if encoder_hidden_states is None: encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0][-1] # output of the encoder *stack* + encoder_hidden_states = encoder_outputs[0][ + -1 + ] # output of the encoder *stack* else: encoder_outputs = () # Decode - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :] decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) return decoder_outputs + encoder_outputs @@ -201,9 +224,25 @@ class Model2Model(PreTrainedSeq2seq): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): - model = super(Model2Model, cls).from_pretrained(encoder_pretrained_model_name_or_path=pretrained_model_name_or_path, - decoder_pretrained_model_name_or_path=pretrained_model_name_or_path, - **kwargs) + + if ( + "bert" not in pretrained_model_name_or_path + or "roberta" in pretrained_model_name_or_path + or "distilbert" in pretrained_model_name_or_path + ): + raise ValueError("Only the Bert model is currently supported.") + + model = super(Model2Model, cls).from_pretrained( + encoder_pretrained_model_name_or_path=pretrained_model_name_or_path, + decoder_pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs + ) + + # Some architectures require for the decoder to be initialized randomly + # before fine-tuning. + if kwargs.get("decoder_initialize_randomly", False): + model.decoder.init_weights() + return model