From dfce40969141eb037e8af3ed64e490a876386bf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 29 Oct 2019 17:10:20 +0100 Subject: [PATCH] resolve PR comments --- examples/run_summarization_finetuning.py | 292 +++++----------- examples/run_summarization_finetuning_test.py | 76 ---- examples/utils_summarization.py | 184 ++++++++++ examples/utils_summarization_test.py | 133 +++++++ transformers/modeling_beam_search.py | 325 ++++++++++-------- transformers/modeling_bert.py | 31 +- transformers/modeling_seq2seq.py | 79 +++-- 7 files changed, 647 insertions(+), 473 deletions(-) delete mode 100644 examples/run_summarization_finetuning_test.py create mode 100644 examples/utils_summarization.py create mode 100644 examples/utils_summarization_test.py diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py index 64bee82c5b..1888f56caf 100644 --- a/examples/run_summarization_finetuning.py +++ b/examples/run_summarization_finetuning.py @@ -16,10 +16,9 @@ """ Finetuning seq2seq models for sequence generation.""" import argparse -from collections import deque +import functools import logging import os -import pickle import random import sys @@ -29,7 +28,22 @@ import torch from torch.optim import Adam from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler -from transformers import AutoTokenizer, PreTrainedSeq2seq, Model2Model +from transformers import ( + AutoTokenizer, + BertForMaskedLM, + BertConfig, + PreTrainedSeq2seq, + Model2Model, +) + +from utils_summarization import ( + CNNDailyMailDataset, + encode_for_summarization, + fit_to_block_size, + build_lm_labels, + build_mask, + compute_token_type_ids, +) logger = logging.getLogger(__name__) logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -46,194 +60,41 @@ def set_seed(args): # ------------ -class TextDataset(Dataset): - """ Abstracts the dataset used to train seq2seq models. - - CNN/Daily News: - - The CNN/Daily News raw datasets are downloaded from [1]. The stories are - stored in different files; the summary appears at the end of the story as - sentences that are prefixed by the special `@highlight` line. To process - the data, untar both datasets in the same folder, and pass the path to this - folder as the "data_dir argument. The formatting code was inspired by [2]. - - [1] https://cs.nyu.edu/~kcho/ - [2] https://github.com/abisee/cnn-dailymail/ - """ - - def __init__(self, tokenizer, prefix="train", data_dir="", block_size=512): - assert os.path.isdir(data_dir) - - # Load the features that have already been computed, if any - cached_features_file = os.path.join( - data_dir, "cached_lm_{}_{}".format(block_size, prefix) - ) - if os.path.exists(cached_features_file): - logger.info("Loading features from cached file %s", cached_features_file) - with open(cached_features_file, "rb") as source: - self.examples = pickle.load(source) - return - - logger.info("Creating features from dataset at %s", data_dir) - datasets = ["cnn", "dailymail"] - - self.examples = {"source": [], "target": []} - for dataset in datasets: - path_to_stories = os.path.join(data_dir, dataset, "stories") - story_filenames_list = os.listdir(path_to_stories) - for story_filename in story_filenames_list: - path_to_story = os.path.join(path_to_stories, story_filename) - if not os.path.isfile(path_to_story): - continue - - with open(path_to_story, encoding="utf-8") as source: - raw_story = source.read() - story_lines, summary_lines = process_story(raw_story) - if len(summary_lines) == 0 or len(story_lines) == 0: - continue - - story_token_ids, summary_token_ids = _encode_for_summarization( - story_lines, summary_lines, tokenizer - ) - story_seq = _fit_to_block_size(story_token_ids, block_size) - self.examples["source"].append(story_seq) - - summary_seq = _fit_to_block_size(summary_token_ids, block_size) - self.examples["summary"].append(summary_seq) - - logger.info("Saving features into cache file %s", cached_features_file) - with open(cached_features_file, "wb") as sink: - pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL) - - def __len__(self): - return len(self.examples) - - def __getitem__(self, items): - return ( - torch.tensor(self.examples["source"][items]), - torch.tensor(self.examples["target"][items]), - ) - - -def process_story(raw_story): - """ Extract the story and summary from a story file. - - Attributes: - raw_story (str): content of the story file as an utf-8 encoded string. - - Raises: - IndexError: If the stoy is empty or contains no highlights. - """ - nonempty_lines = list( - filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) - ) - - # for some unknown reason some lines miss a period, add it - nonempty_lines = [_add_missing_period(line) for line in nonempty_lines] - - # gather article lines - story_lines = [] - lines = deque(nonempty_lines) - while True: - try: - element = lines.popleft() - if element.startswith("@highlight"): - break - story_lines.append(element) - except IndexError: - # if "@highlight" is absent from the file we pop - # all elements until there is None. - return story_lines, [] - - # gather summary lines - summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines)) - - return story_lines, summary_lines - - -def _encode_for_summarization(story_lines, summary_lines, tokenizer): - """ Encode the story and summary lines, and join them - as specified in [1] by using `[SEP] [CLS]` tokens to separate - sentences. - """ - story_lines_token_ids = [ - tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) - for line in story_lines - ] - summary_lines_token_ids = [ - tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) - for line in summary_lines - ] - - story_token_ids = [ - token for sentence in story_lines_token_ids for token in sentence - ] - summary_token_ids = [ - token for sentence in summary_lines_token_ids for token in sentence - ] - - return story_token_ids, summary_token_ids - - -def _add_missing_period(line): - END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"] - if line.startswith("@highlight"): - return line - if line[-1] in END_TOKENS: - return line - return line + "." - - -def _fit_to_block_size(sequence, block_size): - """ Adapt the source and target sequences' lengths to the block size. - If the sequence is shorter than the block size we pad it with -1 ids - which correspond to padding tokens. - """ - if len(sequence) > block_size: - return sequence[:block_size] - else: - sequence.extend([0] * (block_size - len(sequence))) - return sequence - - -def mask_padding_tokens(sequence): - """ Padding token, encoded as 0, are represented by the value -1 in the - masks """ - padded = sequence.clone() - padded[padded == 0] = -1 - return padded - - def load_and_cache_examples(args, tokenizer): - dataset = TextDataset(tokenizer, data_dir=args.data_dir) + dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir) return dataset -def compute_token_type_ids(batch, separator_token_id): - """ Segment embeddings as described in [1] +def collate(data, tokenizer, block_size): + """ List of tuple as an input. """ + # remove the files with empty an story/summary, encode and fit to block + data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data) + data = [ + encode_for_summarization(story, summary, tokenizer) for story, summary in data + ] + data = [ + ( + fit_to_block_size(story, block_size, tokenizer.pad_token_id), + fit_to_block_size(summary, block_size, tokenizer.pad_token_id), + ) + for story, summary in data + ] - The values {0,1} were found in the repository [2]. + stories = torch.tensor([story for story, summary in data]) + summaries = torch.tensor([summary for story, summary in data]) + encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id) + encoder_mask = build_mask(stories, tokenizer.pad_token_id) + decoder_mask = build_mask(summaries, tokenizer.pad_token_id) + lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id) - Attributes: - batch: torch.Tensor, size [batch_size, block_size] - Batch of input. - separator_token_id: int - The value of the token that separates the segments. - - [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." - arXiv preprint arXiv:1908.08345 (2019). - [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217) - """ - batch_embeddings = [] - sentence_num = 0 - for sequence in batch: - embeddings = [] - for s in sequence: - if s == separator_token_id: - sentence_num += 1 - embeddings.append(sentence_num % 2) - batch_embeddings.append(embeddings) - return torch.tensor(batch_embeddings) + return ( + stories, + summaries, + encoder_token_type_ids, + encoder_mask, + decoder_mask, + lm_labels, + ) # ---------- @@ -252,7 +113,7 @@ class BertSumOptimizer(object): arXiv preprint arXiv:1908.08345 (2019). """ - def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-9): + def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8): self.encoder = model.encoder self.decoder = model.decoder self.lr = lr @@ -306,8 +167,12 @@ def train(args, model, tokenizer): args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) train_dataset = load_and_cache_examples(args, tokenizer) train_sampler = RandomSampler(train_dataset) + model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512) train_dataloader = DataLoader( - train_dataset, sampler=train_sampler, batch_size=args.train_batch_size + train_dataset, + sampler=train_sampler, + batch_size=args.train_batch_size, + collate_fn=model_collate_fn, ) # Training schedule @@ -351,26 +216,23 @@ def train(args, model, tokenizer): for _ in train_iterator: epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True) for step, batch in enumerate(epoch_iterator): - source, target = batch - token_type_ids = compute_token_type_ids(source, tokenizer.cls_token_id) - labels_src = mask_padding_tokens(source) - labels_tgt = mask_padding_tokens(target) + source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch source = source.to(args.device) target = target.to(args.device) - token_type_ids = token_type_ids.to(args.device) - labels_src = labels_src.to(args.device) - labels_tgt = labels_tgt.to(args.device) + encoder_token_type_ids = encoder_token_type_ids.to(args.device) + encoder_mask = encoder_mask.to(args.device) + decoder_mask = decoder_mask.to(args.device) + lm_labels = lm_labels.to(args.device) model.train() outputs = model( source, target, - token_type_ids=token_type_ids, - decoder_encoder_attention_mask=labels_src, - decoder_attention_mask=labels_tgt, - decoder_lm_labels=labels_tgt, - decoder_initialize_randomly=True, + encoder_token_type_ids=encoder_token_type_ids, + encoder_attention_mask=encoder_mask, + decoder_attention_mask=decoder_mask, + decoder_lm_labels=lm_labels, ) loss = outputs[0] @@ -421,21 +283,23 @@ def evaluate(args, model, tokenizer, prefix=""): model.eval() for batch in tqdm(eval_dataloader, desc="Evaluating"): - source, target = batch - labels_src = mask_padding_tokens(source) - labels_tgt = mask_padding_tokens(target) - source.to(args.device) - target.to(args.device) - labels_src.to(args.device) - labels_tgt.to(args.device) + source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch + + source = source.to(args.device) + target = target.to(args.device) + encoder_token_type_ids = encoder_token_type_ids.to(args.device) + encoder_mask = encoder_mask.to(args.device) + decoder_mask = decoder_mask.to(args.device) + lm_labels = lm_labels.to(args.device) with torch.no_grad(): outputs = model( source, target, - decoder_encoder_attention_mask=labels_src, - decoder_attention_mask=labels_tgt, - decoder_lm_labels=labels_tgt, + encoder_token_type_ids=encoder_token_type_ids, + encoder_attention_mask=encoder_mask, + decoder_attention_mask=decoder_mask, + decoder_lm_labels=lm_labels, ) lm_loss = outputs[0] eval_loss += lm_loss.mean().item() @@ -525,7 +389,7 @@ def main(): ) parser.add_argument( "--num_train_epochs", - default=1, + default=10, type=int, help="Total number of training epochs to perform.", ) @@ -558,9 +422,13 @@ def main(): args.device = torch.device("cuda") args.n_gpu = torch.cuda.device_count() - # Load pretrained model and tokenizer + # Load pretrained model and tokenizer. The decoder's weights are randomly initialized. tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) - model = Model2Model.from_pretrained(args.model_name_or_path) + config = BertConfig.from_pretrained(args.model_name_or_path) + decoder_model = BertForMaskedLM(config) + model = Model2Model.from_pretrained( + args.model_name_or_path, decoder_model=decoder_model + ) # Setup logging logging.basicConfig( diff --git a/examples/run_summarization_finetuning_test.py b/examples/run_summarization_finetuning_test.py deleted file mode 100644 index fd997ee0c2..0000000000 --- a/examples/run_summarization_finetuning_test.py +++ /dev/null @@ -1,76 +0,0 @@ -# coding=utf-8 -# Copyright 2019 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import unittest - -from run_summarization_finetuning import _fit_to_block_size, process_story - - -class DataLoaderTest(unittest.TestCase): - def setUp(self): - self.block_size = 10 - - def test_truncate_sequence_too_small(self): - """ Pad the sequence with 0 if the sequence is smaller than the block size.""" - sequence = [1, 2, 3, 4] - expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0] - self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) - - def test_truncate_sequence_fit_exactly(self): - sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) - - def test_truncate_sequence_too_big(self): - sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] - expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - self.assertEqual(_fit_to_block_size(sequence, self.block_size), expected_output) - - def test_process_story_no_highlights(self): - """ Processing a story with no highlights should raise an exception. - """ - raw_story = """It was the year of Our Lord one thousand seven hundred and - seventy-five.\n\nSpiritual revelations were conceded to England at that - favoured period, as at this.""" - _, summary = process_story(raw_story) - self.assertEqual(summary, []) - - def test_process_empty_story(self): - """ An empty story should also raise and exception. - """ - raw_story = "" - story, summary = process_story(raw_story) - self.assertEqual(story, []) - self.assertEqual(summary, []) - - def test_story_with_missing_period(self): - raw_story = ( - "It was the year of Our Lord one thousand seven hundred and " - "seventy-five\n\nSpiritual revelations were conceded to England " - "at that favoured period, as at this.\n@highlight\n\nIt was the best of times" - ) - story_lines, summary_lines = process_story(raw_story) - - expected_story_lines = [ - "It was the year of Our Lord one thousand seven hundred and seventy-five.", - "Spiritual revelations were conceded to England at that favoured period, as at this.", - ] - self.assertEqual(expected_story_lines, story_lines) - - expected_summary_lines = ["It was the best of times."] - self.assertEqual(expected_summary_lines, summary_lines) - - -if __name__ == "__main__": - unittest.main() diff --git a/examples/utils_summarization.py b/examples/utils_summarization.py new file mode 100644 index 0000000000..cd8bc4bc2b --- /dev/null +++ b/examples/utils_summarization.py @@ -0,0 +1,184 @@ +from collections import deque +import os + +import torch +from torch.utils.data import Dataset + + +# ------------ +# Data loading +# ------------ + + +class CNNDailyMailDataset(Dataset): + """ Abstracts the dataset used to train seq2seq models. + + CNN/Daily News: + + The CNN/Daily News raw datasets are downloaded from [1]. The stories are + stored in different files; the summary appears at the end of the story as + sentences that are prefixed by the special `@highlight` line. To process + the data, untar both datasets in the same folder, and pass the path to this + folder as the "data_dir argument. The formatting code was inspired by [2]. + + [1] https://cs.nyu.edu/~kcho/ + [2] https://github.com/abisee/cnn-dailymail/ + """ + + def __init__(self, tokenizer, prefix="train", data_dir=""): + assert os.path.isdir(data_dir) + self.tokenizer = tokenizer + + # We initialize the class by listing all the files that contain + # stories and summaries. Files are not read in memory given + # the size of the corpus. + self.stories_path = [] + datasets = ("cnn", "dailymail") + for dataset in datasets: + path_to_stories = os.path.join(data_dir, dataset, "stories") + story_filenames_list = os.listdir(path_to_stories) + for story_filename in story_filenames_list: + path_to_story = os.path.join(path_to_stories, story_filename) + if not os.path.isfile(path_to_story): + continue + self.stories_path.append(path_to_story) + + def __len__(self): + return len(self.stories_path) + + def __getitem__(self, idx): + story_path = self.stories_path[idx] + with open(story_path, encoding="utf-8") as source: + raw_story = source.read() + story_lines, summary_lines = process_story(raw_story) + return story_lines, summary_lines + + +def process_story(raw_story): + """ Extract the story and summary from a story file. + + Attributes: + raw_story (str): content of the story file as an utf-8 encoded string. + + Raises: + IndexError: If the stoy is empty or contains no highlights. + """ + nonempty_lines = list( + filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]) + ) + + # for some unknown reason some lines miss a period, add it + nonempty_lines = [_add_missing_period(line) for line in nonempty_lines] + + # gather article lines + story_lines = [] + lines = deque(nonempty_lines) + while True: + try: + element = lines.popleft() + if element.startswith("@highlight"): + break + story_lines.append(element) + except IndexError: + # if "@highlight" is absent from the file we pop + # all elements until there is None. + return story_lines, [] + + # gather summary lines + summary_lines = list(filter(lambda t: not t.startswith("@highlight"), lines)) + + return story_lines, summary_lines + + +def _add_missing_period(line): + END_TOKENS = [".", "!", "?", "...", "'", "`", '"', u"\u2019", u"\u2019", ")"] + if line.startswith("@highlight"): + return line + if line[-1] in END_TOKENS: + return line + return line + "." + + +# -------------------------- +# Encoding and preprocessing +# -------------------------- + + +def fit_to_block_size(sequence, block_size, pad_token): + """ Adapt the source and target sequences' lengths to the block size. + If the sequence is shorter than the block size we pad it with -1 ids + which correspond to padding tokens. + """ + if len(sequence) > block_size: + return sequence[:block_size] + else: + sequence.extend([pad_token] * (block_size - len(sequence))) + return sequence + + +def build_lm_labels(sequence, pad_token): + """ Padding token, encoded as 0, are represented by the value -1 so they + are not taken into account in the loss computation. """ + padded = sequence.clone() + padded[padded == pad_token] = -1 + return padded + + +def build_mask(sequence, pad_token): + """ Builds the mask. The attention mechanism will only attend to positions + with value 1. """ + mask = sequence.clone() + mask[mask != pad_token] = 1 + mask[mask == pad_token] = 0 + return mask + + +def encode_for_summarization(story_lines, summary_lines, tokenizer): + """ Encode the story and summary lines, and join them + as specified in [1] by using `[SEP] [CLS]` tokens to separate + sentences. + """ + story_lines_token_ids = [ + tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) + for line in story_lines + ] + summary_lines_token_ids = [ + tokenizer.add_special_tokens_single_sequence(tokenizer.encode(line)) + for line in summary_lines + ] + + story_token_ids = [ + token for sentence in story_lines_token_ids for token in sentence + ] + summary_token_ids = [ + token for sentence in summary_lines_token_ids for token in sentence + ] + + return story_token_ids, summary_token_ids + + +def compute_token_type_ids(batch, separator_token_id): + """ Segment embeddings as described in [1] + + The values {0,1} were found in the repository [2]. + + Attributes: + batch: torch.Tensor, size [batch_size, block_size] + Batch of input. + separator_token_id: int + The value of the token that separates the segments. + + [1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders." + arXiv preprint arXiv:1908.08345 (2019). + [2] https://github.com/nlpyang/PreSumm (/src/prepro/data_builder.py, commit fac1217) + """ + batch_embeddings = [] + for sequence in batch: + sentence_num = 0 + embeddings = [] + for s in sequence: + if s == separator_token_id: + sentence_num += 1 + embeddings.append(sentence_num % 2) + batch_embeddings.append(embeddings) + return torch.tensor(batch_embeddings) diff --git a/examples/utils_summarization_test.py b/examples/utils_summarization_test.py new file mode 100644 index 0000000000..7a02f8fa1f --- /dev/null +++ b/examples/utils_summarization_test.py @@ -0,0 +1,133 @@ +# coding=utf-8 +# Copyright 2019 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import torch + +from utils_summarization import ( + compute_token_type_ids, + fit_to_block_size, + build_mask, + build_lm_labels, + process_story, +) + + +class SummarizationDataProcessingTest(unittest.TestCase): + def setUp(self): + self.block_size = 10 + + def test_fit_to_block_sequence_too_small(self): + """ Pad the sequence with 0 if the sequence is smaller than the block size.""" + sequence = [1, 2, 3, 4] + expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0] + self.assertEqual( + fit_to_block_size(sequence, self.block_size, 0), expected_output + ) + + def test_fit_to_block_sequence_fit_exactly(self): + """ Do nothing if the sequence is the right size. """ + sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + self.assertEqual( + fit_to_block_size(sequence, self.block_size, 0), expected_output + ) + + def test_fit_to_block_sequence_too_big(self): + """ Truncate the sequence if it is too long. """ + sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + self.assertEqual( + fit_to_block_size(sequence, self.block_size, 0), expected_output + ) + + def test_process_story_no_highlights(self): + """ Processing a story with no highlights returns an empty list for the summary. + """ + raw_story = """It was the year of Our Lord one thousand seven hundred and + seventy-five.\n\nSpiritual revelations were conceded to England at that + favoured period, as at this.""" + _, summary_lines = process_story(raw_story) + self.assertEqual(summary_lines, []) + + def test_process_empty_story(self): + """ An empty story returns an empty collection of lines. + """ + raw_story = "" + story_lines, summary_lines = process_story(raw_story) + self.assertEqual(story_lines, []) + self.assertEqual(summary_lines, []) + + def test_process_story_with_missing_period(self): + raw_story = ( + "It was the year of Our Lord one thousand seven hundred and " + "seventy-five\n\nSpiritual revelations were conceded to England " + "at that favoured period, as at this.\n@highlight\n\nIt was the best of times" + ) + story_lines, summary_lines = process_story(raw_story) + + expected_story_lines = [ + "It was the year of Our Lord one thousand seven hundred and seventy-five.", + "Spiritual revelations were conceded to England at that favoured period, as at this.", + ] + self.assertEqual(expected_story_lines, story_lines) + + expected_summary_lines = ["It was the best of times."] + self.assertEqual(expected_summary_lines, summary_lines) + + def test_build_lm_labels_no_padding(self): + sequence = torch.tensor([1, 2, 3, 4]) + expected = sequence + np.testing.assert_array_equal( + build_lm_labels(sequence, 0).numpy(), expected.numpy() + ) + + def test_build_lm_labels(self): + sequence = torch.tensor([1, 2, 3, 4, 0, 0, 0]) + expected = torch.tensor([1, 2, 3, 4, -1, -1, -1]) + np.testing.assert_array_equal( + build_lm_labels(sequence, 0).numpy(), expected.numpy() + ) + + def test_build_mask_no_padding(self): + sequence = torch.tensor([1, 2, 3, 4]) + expected = torch.tensor([1, 1, 1, 1]) + np.testing.assert_array_equal( + build_mask(sequence, 0).numpy(), expected.numpy() + ) + + def test_build_mask(self): + sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23]) + expected = torch.tensor([1, 1, 1, 1, 0, 0, 0]) + np.testing.assert_array_equal( + build_mask(sequence, 23).numpy(), expected.numpy() + ) + + def test_compute_token_type_ids(self): + separator = 101 + batch = torch.tensor( + [[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]] + ) + expected = torch.tensor( + [[0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 1], [0, 1, 1, 1, 0, 0]] + ) + + result = compute_token_type_ids(batch, separator) + np.testing.assert_array_equal(result, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/transformers/modeling_beam_search.py b/transformers/modeling_beam_search.py index 3a27625f90..171dcb7247 100644 --- a/transformers/modeling_beam_search.py +++ b/transformers/modeling_beam_search.py @@ -26,189 +26,220 @@ import torch from torch import nn -class ModelWithBeamSearch(nn.Module): +class TransformerBeamSearch(nn.Module): def __init__( self, model, + tokenizer, + batch_size, beam_size, - start_token_id, - end_token_id, - pad_token_id, min_length, max_length, - alpha, - block_trigram=True, + alpha=0, + block_repeating_trigram=True, ): """ Attributes: mask_word_id: token id that corresponds to the mask """ - super(ModelWithBeamSearch, self).__init__() + super(TransformerBeamSearch, self).__init__() self.model = model + self.tokenizer = tokenizer + + self.start_token_id = tokenizer.start_token_id + self.end_token_id = tokenizer.end_token_id + self.pad_token_id = tokenizer.pad_token_id + self.beam_size = beam_size - self.start_token_id = start_token_id - self.end_token_id = end_token_id - self.pad_token_id = pad_token_id self.min_length = min_length self.max_length = max_length - self.alpha = alpha - self.block_trigram = block_trigram - def forward(self, input_ids, **kwargs): - # Separate the encoder- and decoder- specific kwargs. A kwarg is - # decoder-specific it the key starts with `decoder_` + self.block_repeating_trigram = block_repeating_trigram + self.apply_length_penalty = False if alpha == 0 else True + self.alpha = alpha + + # State of the beam + self.hypotheses = [[] for _ in range(batch_size)] + self.batch_offset = torch.arange(batch_size, dtype=torch.long) + self.beam_offset = torch.arange( + 0, batch_size * self.beam_size, step=self.beam_size, dtype=torch.long + ) + self.growing_beam = torch.full( + (batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long + ) + self.topk_log_probabilities = torch.tensor( + [0.0] + [float("-inf")] * (self.beam_size - 1), dtype=torch.float + ).repeat(batch_size) + self.results = { + "prediction": [[] for _ in batch_size], + "scores": [[] for _ in batch_size], + } + self._step = 0 + self.is_done = False + + def step(self, log_probabilities): + """ Grows the beam by one step. """ + self._step += 1 + + # The batch size changes as some beams finish so we define _B + vocab_size = log_probabilities.size(-1) + _B = log_probabilities.size(0) // self.beam_size + + # Multiply each beam probability with the probability of the + # next token (conditioned on the words in the beam). + log_probabilities += self.topk_log_probabilities.view(-1, 1) + + self.enforce_min_length(log_probabilities) + if self.block_repeating_trigram: + self.remove_repeating_trigrams(log_probabilities, _B) + + # Find the `beam_size` (previous_beam + token) combinations with + # the highest score + topk_log_probabilities, topk_ids = log_probabilities.topk( + log_probabilities.view(_B, self.beam_size * vocab_size), + self.beam_size, + dim=1, + ) + + # Apply the length penalty. The +1 accounts for the [EOS] token + # that will be added if the beam ends. + topk_scores = topk_log_probabilities / self.length_penalty() + + # Retrieve the corresponding respective beam and token id + # topk_token_ids[i] will be added to topk_beam_ids[i] + topk_beam_ids = topk_ids.div(vocab_size) + topk_token_ids = topk_ids.fmod(vocab_size) + + # Retrieve the row index of the surviving beams in the original + # view of the log_probabilities tensor + surviving_beams_rows = (topk_beam_ids + self.beam_offset[:_B].view(-1, 1)).view( + -1 + ) + + # Append the last predictions + self.growing_beam = torch.cat( + [ + self.growing_beam.index_select(0, surviving_beams_rows), + topk_token_ids.view(-1, 1), + ], + 1, + ) + + # Check if any of the beam searches has ended during this + # growth step. Also if top beam (most probable) has ended + # for one element of the batch. + is_finished = topk_token_ids.eq(self.end_token_id) + self.enforce_max_length() + is_top_beam_finished = is_finished[:, 0].eq(1) + + # Save the finished searches + if is_finished.any(): + predictions = self.growing_beam.view( + -1, self.beam_size, self.growing_beam.size(1) + ) + for i in range(is_finished.size(0)): + if is_top_beam_finished[i]: + is_finished[i].fill_(1) + finished_hyp = is_finished[i].nonzero().view(-1) + + # Store finished hypotheses for this batch. + b = self.batch_offset[i] + for j in finished_hyp: + self.hypotheses[b].append((topk_scores[i, j], predictions[i, j, :])) + + # If the batch reached the end, save the best hypotheses + # in terms of length-penalized score. + if is_top_beam_finished[i]: + best_hyp = sorted( + self.hypotheses[b], key=lambda x: x[0], reverse=True + ) + best_score, best_prediction = best_hyp[0] + self.results["scores"][b].append(best_score) + self.results["predictions"][b].append(best_prediction) + + non_finished = is_top_beam_finished.eq(0).nonzero().view(-1) + if len(non_finished) == 0: + self.is_done = True + + # Remove finished batches for the next step. + topk_log_probabilities = topk_log_probabilities.index_select( + 0, non_finished + ) + self.batch_offset = self.batch_offset.index_select(0, non_finished) + self.growing_beam = predictions.index_select(0, non_finished).view( + -1, self.growing_beam.size(-1) + ) + + surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished) + + return surviving_beams_rows + + def forward(self, encoder_input_ids, **kwargs): + # keyword arguments come in 3 flavors: encoder-specific (prefixed by + # `encoder_`), decoder-specific (prefixed by `decoder_`) and those + # that apply to the model as whole. + # We let the specific kwargs override the common ones in case of conflict. kwargs_encoder = { - argument: value + argument[len("encoder_"):]: value for argument, value in kwargs.items() - if not argument.startswith("decoder_") + if argument.startswith("encoder_") } kwargs_decoder = { argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + kwargs_common = { + argument: value + for argument, value in kwargs.items() + if not (argument.startswith("encoder_") or argument.startswith("decoder_")) + } + kwargs_decoder = dict(kwargs_common, **kwargs_decoder) + kwargs_encoder = dict(kwargs_common, **kwargs_encoder) - batch_size, _ = input_ids.size(0) - - # Variables that keep track of the status of the search - hypotheses = [[] for _ in range(batch_size)] - batch_offset = torch.arange(batch_size, dtype=torch.long) - beam_offset = torch.arange( - 0, - batch_size * self.beam_size, - step=self.beam_size, - dtype=torch.long, - ) - growing_beam = torch.full( - (batch_size * self.beam_size, 1), - self.start_token_id, - dtype=torch.long, - ) - topk_log_probabilities = torch.tensor( - [0.0] + [float("-inf")] * (self.beam_size - 1), - dtype=torch.float, - ).repeat(batch_size) - - # Forward pass on the encoder - encoder_outputs = self.encoder(input_ids, kwargs_encoder) + # forward pass on the encoder + encoder_outputs = self.model.encoder.forward(encoder_input_ids, kwargs_encoder) kwargs_decoder["encoder_hidden_states"] = tile( encoder_outputs, self.beam_size, dim=0 ) - results = {} - results["predictions"] = [[] for _ in batch_size] - results["scores"] = [[] for _ in batch_size] - + # grow the beam by generating sequences in an autoregressive way + self.growing_beam = torch.full( + (self.batch_size * self.beam_size, 1), self.start_token_id, dtype=torch.long + ) for step in range(self.max_length): - decoder_input = growing_beam[:, -1] - outputs = self.decoder(decoder_input, kwargs_decoder) + decoder_input = self.growing_beam[:, -1] + outputs = self.model.decoder(decoder_input, kwargs_decoder) log_probabilities = torch.nn.functional.log_softmax(outputs[1]) - vocab_size = log_probabilities.size(-1) + surviving_beams_rows = self.step(log_probabilities) + if self.is_done: + break - # The batch size changes as some beams finish so we define: - _B = log_probabilities.size(0) // self.beam_size - - # Multiply each beam probability with the probability of the - # next token (conditioned on the words in the beam). - log_probabilities += topk_log_probabilities.view(-1, 1) - - # if the beam has not attained the minimum required length we - # make the end token arbitrarily unlikely. - if step < self.min_length: - log_probabilities[self.end_token_id] = -1e20 - - # Remove repeating tri-grams - if(self.args.block_trigram): - if(step + 1 > 3): - for i in range(_B * self.beam_size): - tokens = [t for t in growing_beam[i]] - trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)] - last_trigram = tuple(trigrams[-1]) - if last_trigram in trigrams[:-1]: - log_probabilities[i] = -1e20 - - # Find the `beam_size` (previous_beam + token) combinations with - # the highest score - topk_log_probabilities, topk_ids = log_probabilities.topk( - log_probabilities.view(_B, self.beam_size * vocab_size), - self.beam_size, - dim=1 - ) - - # Apply the length penalty. The +1 accounts for the [EOS] token - # that will be added if the beam ends. - length_penalty = ((5.0 + (step + 1)) / 6.0) ** self.alpha - topk_scores = topk_log_probabilities / length_penalty - - # Retrieve the corresponding respective beam and token id - # topk_token_ids[i] will be added to topk_beam_ids[i] - topk_beam_ids = topk_ids.div(vocab_size) - topk_token_ids = topk_ids.fmod(vocab_size) - - # Retrieve the row index of the surviving beams in the original - # view of the log_probabilities tensor - surviving_beams_rows = ( - topk_beam_ids + beam_offset[:_B].view(-1, 1) - ).view(-1) - - # Append the last predictions - growing_beam = torch.cat( - [ - growing_beam.index_select(0, surviving_beams_rows), - topk_token_ids.view(-1, 1), - ], - 1, - ) - - # Check if any of the beam searches has ended during this - # growth step. Also if top beam (most probable) has ended - # for one element of the batch. - is_finished = topk_token_ids.eq(self.end_token_id) - if step + 1 == self.max_length: - is_finished.fill_(1) - is_top_beam_finished = is_finished[:, 0].eq(1) - - # Save the finished searches - if is_finished.any(): - predictions = growing_beam.view(-1, self.beam_size, growing_beam.size(1)) - for i in range(is_finished.size(0)): - if is_top_beam_finished[i]: - is_finished[i].fill_(1) - finished_hyp = is_finished[i].nonzero().view(-1) - - # Store finished hypotheses for this batch. - b = batch_offset[i] - for j in finished_hyp: - hypotheses[b].append((topk_scores[i, j], predictions[i, j, :])) - - # If the batch reached the end, save the best hypotheses - # in terms of length-penalized score. - if is_top_beam_finished[i]: - best_hyp = sorted( - hypotheses[b], key=lambda x: x[0], reverse=True - ) - best_score, best_prediction = best_hyp[0] - results["scores"][b].append(best_score) - results["predictions"][b].append(best_prediction) - - non_finished = is_top_beam_finished.eq(0).nonzero().view(-1) - if len(non_finished) == 0: - break - - # Remove finished batches for the next step. - topk_log_probabilities = topk_log_probabilities.index_select(0, non_finished) - batch_offset = batch_offset.index_select(0, non_finished) - growing_beam = predictions.index_select(0, non_finished).view( - -1, growing_beam.size(-1) - ) - - # Re-order the state for the next pass - surviving_beams_rows = surviving_beams_rows.index_select(0, non_finished) kwargs_decoder["encoder_hidden_states"] = kwargs_decoder[ "encoder_hidden_states" ].index_select(0, surviving_beams_rows) - return results + return self.results + + def remove_repeating_trigrams(self, log_probabilities, _B): + if(self._step + 1 > 3): + for i in range(_B * self.beam_size): + tokens = [t for t in self.growing_beam[i]] + trigrams = [(tokens[i-1], tokens[i], tokens[i+1]) for i in range(1, len(words) - 1)] + last_trigram = tuple(trigrams[-1]) + if last_trigram in trigrams[:-1]: + log_probabilities[i] = -1e20 + + def enforce_min_length(self): + if self._step < self.min_length: + self.log_probabilities[self.end_token_id] = -1e20 + + def enforce_max_length(self): + if self._step + 1 == self.max_length: + self.is_finished.fill_(1) + + def length_penalty(self): + return ((5.0 + (self._step + 1)) / 6.0) ** self.alpha def tile(x, count, dim=0): diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py index 93f3c7e1f1..1081c8dd7b 100644 --- a/transformers/modeling_bert.py +++ b/transformers/modeling_bert.py @@ -632,6 +632,8 @@ class BertModel(BertPreTrainedModel): """ if attention_mask is None: attention_mask = torch.ones_like(input_ids) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones_like(input_ids) if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) @@ -660,12 +662,15 @@ class BertModel(BertPreTrainedModel): extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - # If a 2D encoder attention mask is provided for the cross-attention + # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] - if encoder_attention_mask is not None: - encoder_attention_mask = encoder_attention_mask[:, None, None, :] - encoder_attention_mask = encoder_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - encoder_attention_mask = (1.0 - encoder_attention_mask) * -10000.0 + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -687,7 +692,7 @@ class BertModel(BertPreTrainedModel): attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask) + encoder_attention_mask=encoder_extended_attention_mask) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) @@ -788,8 +793,10 @@ class BertForMaskedLM(BertPreTrainedModel): in ``[0, ..., config.vocab_size]`` Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + **masked_lm_loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: Masked language modeling loss. + **next_token_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Next token prediction loss. **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) @@ -854,13 +861,13 @@ class BertForMaskedLM(BertPreTrainedModel): if lm_labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one - prediction_scores = prediction_scores[:, :-1, :] - lm_labels = lm_labels[:, 1:] + prediction_scores = prediction_scores[:, :-1, :].contiguous() + lm_labels = lm_labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(ignore_index=-1) - seq2seq_loss = loss_fct(prediction_scores.reshape(-1, self.config.vocab_size), lm_labels.reshape(-1)) - outputs = (seq2seq_loss,) + outputs + next_token_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1)) + outputs = (next_token_loss,) + outputs - return outputs # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions) + return outputs # (masked_lm_loss), (next_token_loss), prediction_scores, (hidden_states), (attentions) @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """, diff --git a/transformers/modeling_seq2seq.py b/transformers/modeling_seq2seq.py index 2767dd2cd1..22898db9a1 100644 --- a/transformers/modeling_seq2seq.py +++ b/transformers/modeling_seq2seq.py @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class PreTrainedSeq2seq(nn.Module): r""" - :class:`~transformers.Seq2seq` is a generic model class that will be + :class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be instantiated as a Seq2seq model with one of the base model classes of the library as encoder and (optionally) as decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class @@ -49,8 +49,7 @@ class PreTrainedSeq2seq(nn.Module): *model_args, **kwargs ): - r""" Instantiates an encoder and a decoder from one or two base classes - of the library from pre-trained model checkpoints. + r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) @@ -111,35 +110,44 @@ class PreTrainedSeq2seq(nn.Module): model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert """ - # Separate the encoder- and decoder- specific kwargs. A kwarg is - # decoder-specific it the key starts with `decoder_` + # keyword arguments come in 3 flavors: encoder-specific (prefixed by + # `encoder_`), decoder-specific (prefixed by `decoder_`) and those + # that apply to the model as a whole. + # We let the specific kwargs override the common ones in case of conflict. kwargs_encoder = { - argument: value + argument[len("encoder_"):]: value for argument, value in kwargs.items() - if not argument.startswith("decoder_") + if argument.startswith("encoder_") } kwargs_decoder = { - argument[len("decoder_") :]: value + argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + kwargs_common = { + argument: value + for argument, value in kwargs.items() + if not (argument.startswith("encoder_") or argument.startswith("decoder_")) + } + kwargs_decoder = dict(kwargs_common, **kwargs_decoder) + kwargs_encoder = dict(kwargs_common, **kwargs_encoder) # Load and initialize the encoder and decoder # The distinction between encoder and decoder at the model level is made # by the value of the flag `is_decoder` that we need to set correctly. - encoder = kwargs_encoder.pop("encoder_model", None) + encoder = kwargs_encoder.pop("model", None) if encoder is None: - kwargs_encoder["is_decoder"] = False encoder = AutoModel.from_pretrained( encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder ) + encoder.config.is_decoder = False decoder = kwargs_decoder.pop("model", None) if decoder is None: - kwargs_decoder["is_decoder"] = True decoder = AutoModelWithLMHead.from_pretrained( decoder_pretrained_model_name_or_path, **kwargs_decoder ) + decoder.config.is_decoder = True model = cls(encoder, decoder) @@ -169,37 +177,60 @@ class PreTrainedSeq2seq(nn.Module): decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)`` Indices of decoder input sequence tokens in the vocabulary. """ - # Separate the encoder- and decoder- specific kwargs. A kwarg is - # decoder-specific it the key starts with `decoder_` + # keyword arguments come in 3 flavors: encoder-specific (prefixed by + # `encoder_`), decoder-specific (prefixed by `decoder_`) and those + # that apply to the model as whole. + # We let the specific kwargs override the common ones in case of conflict. kwargs_encoder = { - argument: value + argument[len("encoder_"):]: value for argument, value in kwargs.items() - if not argument.startswith("decoder_") + if argument.startswith("encoder_") } kwargs_decoder = { - argument[len("decoder_") :]: value + argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + kwargs_common = { + argument: value + for argument, value in kwargs.items() + if not (argument.startswith("encoder_") or argument.startswith("decoder_")) + } + kwargs_decoder = dict(kwargs_common, **kwargs_decoder) + kwargs_encoder = dict(kwargs_common, **kwargs_encoder) # Encode if needed (training, first prediction pass) - encoder_hidden_states = kwargs_encoder.pop("encoder_hidden_states", None) + encoder_hidden_states = kwargs_encoder.pop("hidden_states", None) if encoder_hidden_states is None: encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder) - encoder_hidden_states = encoder_outputs[0][ - -1 - ] # output of the encoder *stack* + encoder_hidden_states = encoder_outputs[0] # output the last layer hidden state else: encoder_outputs = () # Decode - kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states[None, :, :] + kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states + kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None) decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder) return decoder_outputs + encoder_outputs class Model2Model(PreTrainedSeq2seq): + r""" + :class:`~transformers.Model2Model` instantiates a Seq2Seq2 model + where both of the encoder and decoder are of the same family. If the + name of or that path to a pretrained model is specified the encoder and + the decoder will be initialized with the pretrained weight (the + cross-attention will be intialized randomly if its weights are not + present). + + It is possible to override this behavior and initialize, say, the decoder randomly + by creating it beforehand as follows + + config = BertConfig.from_pretrained() + decoder = BertForMaskedLM(config) + model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder) + """ def __init__(self, *args, **kwargs): super(Model2Model, self).__init__(*args, **kwargs) self.tie_weights() @@ -235,14 +266,10 @@ class Model2Model(PreTrainedSeq2seq): model = super(Model2Model, cls).from_pretrained( encoder_pretrained_model_name_or_path=pretrained_model_name_or_path, decoder_pretrained_model_name_or_path=pretrained_model_name_or_path, + *args, **kwargs ) - # Some architectures require for the decoder to be initialized randomly - # before fine-tuning. - if kwargs.get("decoder_initialize_randomly", False): - model.decoder.init_weights() - return model