[EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2 * make gpt2 cross attention work * finish bert2gpt2 * add explicit comments * remove attention mask since not yet supported * revert attn mask in pipeline * Update src/transformers/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_encoder_decoder.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
eb613b566a
commit
1d6e71e116
@@ -372,11 +372,16 @@ class GenerationMixin:
|
|||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder:
|
||||||
if decoder_start_token_id is None:
|
if decoder_start_token_id is None:
|
||||||
|
# see if BOS token can be used for decoder_start_token_id
|
||||||
|
if bos_token_id is not None:
|
||||||
decoder_start_token_id = bos_token_id
|
decoder_start_token_id = bos_token_id
|
||||||
|
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
|
||||||
|
decoder_start_token_id = self.config.decoder.bos_token_id
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||||
|
)
|
||||||
|
|
||||||
assert (
|
|
||||||
decoder_start_token_id is not None
|
|
||||||
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
|
||||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||||
|
|
||||||
|
|||||||
@@ -287,6 +287,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
**kwargs_decoder,
|
**kwargs_decoder,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO(PVP): currently it is not possible to use `past`
|
||||||
|
# with the encoder/decoder framework -> should be implemented
|
||||||
return decoder_outputs + encoder_outputs
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
|
||||||
@@ -299,15 +301,24 @@ class EncoderDecoderModel(PreTrainedModel):
|
|||||||
encoder_outputs = (past,)
|
encoder_outputs = (past,)
|
||||||
|
|
||||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
||||||
|
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||||
return {
|
input_dict = {
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"decoder_attention_mask": decoder_inputs["attention_mask"],
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Ideally all models should have a `use_cache`
|
||||||
|
# leave following to ifs until all have it implemented
|
||||||
|
if "use_cache" in decoder_inputs:
|
||||||
|
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
||||||
|
|
||||||
|
if "past_key_values" in decoder_inputs:
|
||||||
|
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
|
||||||
|
|
||||||
|
return input_dict
|
||||||
|
|
||||||
def _reorder_cache(self, past, beam_idx):
|
def _reorder_cache(self, past, beam_idx):
|
||||||
# as a default encoder-decoder models do not re-order the past.
|
# apply decoder cache reordering here
|
||||||
# TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
|
return self.decoder._reorder_cache(past, beam_idx)
|
||||||
return past
|
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
|||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, nx, n_ctx, config, scale=False):
|
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||||
@@ -131,8 +131,12 @@ class Attention(nn.Module):
|
|||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
self.split_size = n_state
|
self.split_size = n_state
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self.is_cross_attention = is_cross_attention
|
||||||
self.c_attn = Conv1D(n_state * 3, nx)
|
if self.is_cross_attention:
|
||||||
|
self.c_attn = Conv1D(2 * n_state, nx)
|
||||||
|
self.q_attn = Conv1D(n_state, nx)
|
||||||
|
else:
|
||||||
|
self.c_attn = Conv1D(3 * n_state, nx)
|
||||||
self.c_proj = Conv1D(n_state, nx)
|
self.c_proj = Conv1D(n_state, nx)
|
||||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||||
@@ -160,6 +164,9 @@ class Attention(nn.Module):
|
|||||||
if self.scale:
|
if self.scale:
|
||||||
w = w / (float(v.size(-1)) ** 0.5)
|
w = w / (float(v.size(-1)) ** 0.5)
|
||||||
nd, ns = w.size(-2), w.size(-1)
|
nd, ns = w.size(-2), w.size(-1)
|
||||||
|
|
||||||
|
if not self.is_cross_attention:
|
||||||
|
# if only "normal" attention layer implements causal mask
|
||||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||||
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
||||||
|
|
||||||
@@ -193,10 +200,26 @@ class Attention(nn.Module):
|
|||||||
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
|
self,
|
||||||
|
hidden_states,
|
||||||
|
layer_past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
use_cache=False,
|
||||||
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
x = self.c_attn(x)
|
if encoder_hidden_states is not None:
|
||||||
query, key, value = x.split(self.split_size, dim=2)
|
assert hasattr(
|
||||||
|
self, "q_attn"
|
||||||
|
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
|
||||||
|
query = self.q_attn(hidden_states)
|
||||||
|
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
else:
|
||||||
|
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||||
|
|
||||||
query = self.split_heads(query)
|
query = self.split_heads(query)
|
||||||
key = self.split_heads(key, k=True)
|
key = self.split_heads(key, k=True)
|
||||||
value = self.split_heads(value)
|
value = self.split_heads(value)
|
||||||
@@ -239,32 +262,64 @@ class MLP(nn.Module):
|
|||||||
class Block(nn.Module):
|
class Block(nn.Module):
|
||||||
def __init__(self, n_ctx, config, scale=False):
|
def __init__(self, n_ctx, config, scale=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
nx = config.n_embd
|
hidden_size = config.n_embd
|
||||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
|
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||||
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.attn = Attention(nx, n_ctx, config, scale)
|
self.attn = Attention(hidden_size, n_ctx, config, scale)
|
||||||
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
|
if config.add_cross_attention:
|
||||||
|
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
|
||||||
|
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = MLP(inner_dim, config)
|
self.mlp = MLP(inner_dim, config)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
|
self,
|
||||||
|
hidden_states,
|
||||||
|
layer_past=None,
|
||||||
|
attention_mask=None,
|
||||||
|
head_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
use_cache=False,
|
||||||
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
output_attn = self.attn(
|
attn_outputs = self.attn(
|
||||||
self.ln_1(x),
|
self.ln_1(hidden_states),
|
||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||||
|
outputs = attn_outputs[1:]
|
||||||
|
# residual connection
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
|
||||||
x = x + a
|
if encoder_hidden_states is not None:
|
||||||
m = self.mlp(self.ln_2(x))
|
# add one self-attention block for cross-attention
|
||||||
x = x + m
|
assert hasattr(
|
||||||
|
self, "crossattention"
|
||||||
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
cross_attn_outputs = self.crossattention(
|
||||||
|
self.ln_cross_attn(hidden_states),
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
attn_output = cross_attn_outputs[0]
|
||||||
|
# residual connection
|
||||||
|
hidden_states = hidden_states + attn_output
|
||||||
|
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
outputs = [x] + output_attn[1:]
|
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
|
||||||
return outputs # x, present, (attentions)
|
# residual connection
|
||||||
|
hidden_states = hidden_states + feed_forward_hidden_states
|
||||||
|
|
||||||
|
outputs = [hidden_states] + outputs
|
||||||
|
return outputs # hidden_states, present, (cross_attentions, attentions)
|
||||||
|
|
||||||
|
|
||||||
class GPT2PreTrainedModel(PreTrainedModel):
|
class GPT2PreTrainedModel(PreTrainedModel):
|
||||||
@@ -449,6 +504,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -506,7 +563,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
@@ -516,6 +573,17 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||||
|
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = None
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
# 1.0 in head_mask indicate we keep the head
|
# 1.0 in head_mask indicate we keep the head
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
@@ -546,6 +614,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
layer_past=layer_past,
|
layer_past=layer_past,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
@@ -593,17 +663,21 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
# only last token for inputs_ids if past is defined in kwargs
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
if past:
|
if past:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"past_key_values": past,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
}
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="ctrl",
|
checkpoint="gpt2",
|
||||||
output_type=CausalLMOutputWithPast,
|
output_type=CausalLMOutputWithPast,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
@@ -616,6 +690,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
position_ids=None,
|
position_ids=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
@@ -648,6 +724,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
|
|||||||
@@ -20,10 +20,9 @@ import unittest
|
|||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
|
|
||||||
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
|
|
||||||
from .test_modeling_bert import BertModelTester
|
from .test_modeling_bert import BertModelTester
|
||||||
from .test_modeling_common import ids_tensor
|
from .test_modeling_common import ids_tensor
|
||||||
|
from .test_modeling_gpt2 import GPT2ModelTester
|
||||||
from .test_modeling_roberta import RobertaModelTester
|
from .test_modeling_roberta import RobertaModelTester
|
||||||
|
|
||||||
|
|
||||||
@@ -31,6 +30,7 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
BertModel,
|
BertModel,
|
||||||
BertLMHeadModel,
|
BertLMHeadModel,
|
||||||
|
GPT2LMHeadModel,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
RobertaForCausalLM,
|
RobertaForCausalLM,
|
||||||
EncoderDecoderModel,
|
EncoderDecoderModel,
|
||||||
@@ -424,3 +424,59 @@ class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def get_pretrained_model(self):
|
def get_pretrained_model(self):
|
||||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
|
return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||||
|
def get_encoder_decoder_model(self, config, decoder_config):
|
||||||
|
encoder_model = BertModel(config)
|
||||||
|
decoder_model = GPT2LMHeadModel(decoder_config)
|
||||||
|
return encoder_model, decoder_model
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
model_tester_encoder = BertModelTester(self, batch_size=13)
|
||||||
|
model_tester_decoder = GPT2ModelTester(self, batch_size=13)
|
||||||
|
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||||
|
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = encoder_config_and_inputs
|
||||||
|
(
|
||||||
|
decoder_config,
|
||||||
|
decoder_input_ids,
|
||||||
|
decoder_input_mask,
|
||||||
|
decoder_head_mask,
|
||||||
|
decoder_token_type_ids,
|
||||||
|
decoder_sequence_labels,
|
||||||
|
decoder_token_labels,
|
||||||
|
decoder_choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
) = decoder_config_and_inputs
|
||||||
|
|
||||||
|
# make sure that cross attention layers are added
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
# disable cache for now
|
||||||
|
decoder_config.use_cache = False
|
||||||
|
return {
|
||||||
|
"config": config,
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"attention_mask": input_mask,
|
||||||
|
"decoder_config": decoder_config,
|
||||||
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
"decoder_token_type_ids": decoder_token_type_ids,
|
||||||
|
"decoder_attention_mask": decoder_input_mask,
|
||||||
|
"decoder_sequence_labels": decoder_sequence_labels,
|
||||||
|
"decoder_token_labels": decoder_token_labels,
|
||||||
|
"decoder_choice_labels": decoder_choice_labels,
|
||||||
|
"encoder_hidden_states": encoder_hidden_states,
|
||||||
|
"labels": decoder_token_labels,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_pretrained_model(self):
|
||||||
|
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from transformers import is_torch_available
|
|||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -62,27 +62,27 @@ class GPT2ModelTester:
|
|||||||
scope=None,
|
scope=None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.batch_size = 14
|
self.batch_size = batch_size
|
||||||
self.seq_length = 7
|
self.seq_length = seq_length
|
||||||
self.is_training = True
|
self.is_training = is_training
|
||||||
self.use_token_type_ids = True
|
self.use_token_type_ids = use_token_type_ids
|
||||||
self.use_input_mask = True
|
self.use_input_mask = use_input_mask
|
||||||
self.use_labels = True
|
self.use_labels = use_labels
|
||||||
self.use_mc_token_ids = True
|
self.use_mc_token_ids = use_mc_token_ids
|
||||||
self.vocab_size = 99
|
self.vocab_size = vocab_size
|
||||||
self.hidden_size = 32
|
self.hidden_size = hidden_size
|
||||||
self.num_hidden_layers = 5
|
self.num_hidden_layers = num_hidden_layers
|
||||||
self.num_attention_heads = 4
|
self.num_attention_heads = num_attention_heads
|
||||||
self.intermediate_size = 37
|
self.intermediate_size = intermediate_size
|
||||||
self.hidden_act = "gelu"
|
self.hidden_act = hidden_act
|
||||||
self.hidden_dropout_prob = 0.1
|
self.hidden_dropout_prob = hidden_dropout_prob
|
||||||
self.attention_probs_dropout_prob = 0, 1
|
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||||
self.max_position_embeddings = 512
|
self.max_position_embeddings = max_position_embeddings
|
||||||
self.type_vocab_size = 16
|
self.type_vocab_size = type_vocab_size
|
||||||
self.type_sequence_label_size = 2
|
self.type_sequence_label_size = type_sequence_label_size
|
||||||
self.initializer_range = 0.02
|
self.initializer_range = initializer_range
|
||||||
self.num_labels = 3
|
self.num_labels = num_labels
|
||||||
self.num_choices = 4
|
self.num_choices = num_choices
|
||||||
self.scope = None
|
self.scope = None
|
||||||
self.bos_token_id = vocab_size - 1
|
self.bos_token_id = vocab_size - 1
|
||||||
self.eos_token_id = vocab_size - 1
|
self.eos_token_id = vocab_size - 1
|
||||||
@@ -142,6 +142,35 @@ class GPT2ModelTester:
|
|||||||
choice_labels,
|
choice_labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_decoder(self):
|
||||||
|
(
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
head_mask,
|
||||||
|
token_type_ids,
|
||||||
|
mc_token_ids,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
) = self.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
|
||||||
|
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
|
||||||
|
|
||||||
|
return (
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
input_mask,
|
||||||
|
head_mask,
|
||||||
|
token_type_ids,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||||
model = GPT2Model(config=config)
|
model = GPT2Model(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user