[EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2 * make gpt2 cross attention work * finish bert2gpt2 * add explicit comments * remove attention mask since not yet supported * revert attn mask in pipeline * Update src/transformers/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_encoder_decoder.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
eb613b566a
commit
1d6e71e116
@@ -20,10 +20,9 @@ import unittest
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
|
||||
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
|
||||
from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_common import ids_tensor
|
||||
from .test_modeling_gpt2 import GPT2ModelTester
|
||||
from .test_modeling_roberta import RobertaModelTester
|
||||
|
||||
|
||||
@@ -31,6 +30,7 @@ if is_torch_available():
|
||||
from transformers import (
|
||||
BertModel,
|
||||
BertLMHeadModel,
|
||||
GPT2LMHeadModel,
|
||||
RobertaModel,
|
||||
RobertaForCausalLM,
|
||||
EncoderDecoderModel,
|
||||
@@ -424,3 +424,59 @@ class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
|
||||
|
||||
|
||||
class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = BertModel(config)
|
||||
decoder_model = GPT2LMHeadModel(decoder_config)
|
||||
return encoder_model, decoder_model
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
model_tester_encoder = BertModelTester(self, batch_size=13)
|
||||
model_tester_decoder = GPT2ModelTester(self, batch_size=13)
|
||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = encoder_config_and_inputs
|
||||
(
|
||||
decoder_config,
|
||||
decoder_input_ids,
|
||||
decoder_input_mask,
|
||||
decoder_head_mask,
|
||||
decoder_token_type_ids,
|
||||
decoder_sequence_labels,
|
||||
decoder_token_labels,
|
||||
decoder_choice_labels,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
# disable cache for now
|
||||
decoder_config.use_cache = False
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"decoder_config": decoder_config,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_token_type_ids": decoder_token_type_ids,
|
||||
"decoder_attention_mask": decoder_input_mask,
|
||||
"decoder_sequence_labels": decoder_sequence_labels,
|
||||
"decoder_token_labels": decoder_token_labels,
|
||||
"decoder_choice_labels": decoder_choice_labels,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
|
||||
|
||||
Reference in New Issue
Block a user