BartForCausalLM analogs to ProphetNetForCausalLM (#9128)

* initiliaze bart4causalLM

* create BartDecoderWrapper, setters/getters

* delete spaces

* forward and additional methods

* update cache function, loss function, remove ngram* params in data class.

* add bartcausallm, bartdecoder testing

* correct bart for causal lm

* remove at

* add mbart as well

* up

* fix typo

* up

* correct

* add pegasusforcausallm

* add blenderbotforcausallm

* add blenderbotsmallforcausallm

* add marianforcausallm

* add test for MarianForCausalLM

* add Pegasus test

* add BlenderbotSmall test

* add blenderbot test

* fix a fail

* fix an import fail

* a fix

* fix

* Update modeling_pegasus.py

* fix models

* fix inputs_embeds setting getter

* adapt tests

* correct repo utils check

* finish test improvement

* fix tf models as well

* make style

* make fix-copies

* fix copies

* run all tests

* last changes

* fix all tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
demSd
2021-02-04 09:56:12 +01:00
committed by GitHub
parent 7898fc03b1
commit 00031785a8
38 changed files with 3595 additions and 177 deletions

View File

@@ -20,6 +20,7 @@ import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bart import BartStandaloneDecoderModelTester
from .test_modeling_bert import BertModelTester
from .test_modeling_bert_generation import BertGenerationEncoderTester
from .test_modeling_common import ids_tensor
@@ -34,6 +35,7 @@ if is_torch_available():
from transformers import (
AutoTokenizer,
BartForCausalLM,
BertGenerationDecoder,
BertGenerationEncoder,
BertLMHeadModel,
@@ -828,3 +830,57 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def test_encoder_decoder_model_shared_weights(self):
pass
@require_torch
class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = BertModel(config)
decoder_model = BartForCausalLM(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = BertModelTester(self, batch_size=13)
model_tester_decoder = BartStandaloneDecoderModelTester(
self, batch_size=13, d_model=32, max_position_embeddings=512
)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
lm_labels,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"labels": lm_labels,
}
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
def test_encoder_decoder_model_shared_weights(self):
pass