[Templates] Adapt Bert (#9284)
* adapt templates * adapt config * add test as well * fix output type * fix cache false naming * finish tests * last fix
This commit is contained in:
committed by
GitHub
parent
88ef8893cd
commit
6c091abef2
@@ -71,6 +71,9 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
|||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
|
||||||
The epsilon used by the layer normalization layers.
|
The epsilon used by the layer normalization layers.
|
||||||
|
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if ``config.is_decoder=True``.
|
||||||
{% else -%}
|
{% else -%}
|
||||||
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
vocab_size (:obj:`int`, `optional`, defaults to 50265):
|
||||||
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
Vocabulary size of the {{cookiecutter.modelname}} model. Defines the number of different tokens that can be represented by the
|
||||||
@@ -146,6 +149,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
|||||||
type_vocab_size=2,
|
type_vocab_size=2,
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
layer_norm_eps=1e-12,
|
layer_norm_eps=1e-12,
|
||||||
|
use_cache=True,
|
||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
{% else -%}
|
{% else -%}
|
||||||
vocab_size=50265,
|
vocab_size=50265,
|
||||||
@@ -199,6 +203,7 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig):
|
|||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.type_vocab_size = type_vocab_size
|
self.type_vocab_size = type_vocab_size
|
||||||
self.layer_norm_eps = layer_norm_eps
|
self.layer_norm_eps = layer_norm_eps
|
||||||
|
self.use_cache = use_cache
|
||||||
{% else -%}
|
{% else -%}
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.encoder_ffn_dim = encoder_ffn_dim
|
self.encoder_ffn_dim = encoder_ffn_dim
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from ...file_utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
MaskedLMOutput,
|
MaskedLMOutput,
|
||||||
MultipleChoiceModelOutput,
|
MultipleChoiceModelOutput,
|
||||||
@@ -160,7 +160,9 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
|||||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||||
|
|
||||||
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
def forward(
|
||||||
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
||||||
|
):
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
else:
|
else:
|
||||||
@@ -169,7 +171,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
|
|||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = self.position_ids[:, :seq_length]
|
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||||
@@ -211,6 +213,8 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
||||||
|
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
|
||||||
def transpose_for_scores(self, x):
|
def transpose_for_scores(self, x):
|
||||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
||||||
x = x.view(*new_x_shape)
|
x = x.view(*new_x_shape)
|
||||||
@@ -223,6 +227,7 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
@@ -230,17 +235,37 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||||||
# If this is instantiated as a cross-attention module, the keys
|
# If this is instantiated as a cross-attention module, the keys
|
||||||
# and values come from an encoder; the attention mask needs to be
|
# and values come from an encoder; the attention mask needs to be
|
||||||
# such that the encoder's padding tokens are not attended to.
|
# such that the encoder's padding tokens are not attended to.
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
mixed_key_layer = self.key(encoder_hidden_states)
|
|
||||||
mixed_value_layer = self.value(encoder_hidden_states)
|
if is_cross_attention and past_key_value is not None:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_layer = past_key_value[0]
|
||||||
|
value_layer = past_key_value[1]
|
||||||
attention_mask = encoder_attention_mask
|
attention_mask = encoder_attention_mask
|
||||||
|
elif is_cross_attention:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
||||||
|
attention_mask = encoder_attention_mask
|
||||||
|
elif past_key_value is not None:
|
||||||
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
||||||
|
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
||||||
else:
|
else:
|
||||||
mixed_key_layer = self.key(hidden_states)
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
mixed_value_layer = self.value(hidden_states)
|
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer)
|
if self.is_decoder:
|
||||||
|
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
||||||
|
# Further calls to cross_attention layer can then reuse all cross-attention
|
||||||
|
# key/value_states (first "if" case)
|
||||||
|
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
||||||
|
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
||||||
|
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
||||||
|
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
||||||
|
past_key_value = (key_layer, value_layer)
|
||||||
|
|
||||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
@@ -284,6 +309,9 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
|||||||
context_layer = context_layer.view(*new_context_layer_shape)
|
context_layer = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (past_key_value,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -335,6 +363,7 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
@@ -343,6 +372,7 @@ class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
|||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = self.output(self_outputs[0], hidden_states)
|
attention_output = self.output(self_outputs[0], hidden_states)
|
||||||
@@ -403,36 +433,60 @@ class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_value=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
):
|
):
|
||||||
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
self_attention_outputs = self.attention(
|
self_attention_outputs = self.attention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
past_key_value=self_attn_past_key_value,
|
||||||
)
|
)
|
||||||
attention_output = self_attention_outputs[0]
|
attention_output = self_attention_outputs[0]
|
||||||
|
|
||||||
|
# if decoder, the last output is tuple of self-attn cache
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = self_attention_outputs[1:-1]
|
||||||
|
present_key_value = self_attention_outputs[-1]
|
||||||
|
else:
|
||||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||||
|
|
||||||
|
cross_attn_present_key_value = None
|
||||||
if self.is_decoder and encoder_hidden_states is not None:
|
if self.is_decoder and encoder_hidden_states is not None:
|
||||||
assert hasattr(
|
assert hasattr(
|
||||||
self, "crossattention"
|
self, "crossattention"
|
||||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||||
|
|
||||||
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
cross_attention_outputs = self.crossattention(
|
cross_attention_outputs = self.crossattention(
|
||||||
attention_output,
|
attention_output,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
cross_attn_past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
attention_output = cross_attention_outputs[0]
|
attention_output = cross_attention_outputs[0]
|
||||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
||||||
|
|
||||||
|
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
||||||
|
cross_attn_present_key_value = cross_attention_outputs[-1]
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
layer_output = apply_chunking_to_forward(
|
layer_output = apply_chunking_to_forward(
|
||||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||||
)
|
)
|
||||||
outputs = (layer_output,) + outputs
|
outputs = (layer_output,) + outputs
|
||||||
|
|
||||||
|
# if decoder, return the attn key/values as the last output
|
||||||
|
if self.is_decoder:
|
||||||
|
outputs = outputs + (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def feed_forward_chunk(self, attention_output):
|
def feed_forward_chunk(self, attention_output):
|
||||||
@@ -455,6 +509,8 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
head_mask=None,
|
head_mask=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -462,17 +518,19 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
|
||||||
|
next_decoder_cache = () if use_cache else None
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||||
|
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if getattr(self.config, "gradient_checkpointing", False):
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
return module(*inputs, output_attentions)
|
return module(*inputs, past_key_value, output_attentions)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -491,9 +549,13 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
layer_head_mask,
|
layer_head_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache += (layer_outputs[-1],)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
@@ -505,11 +567,18 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_decoder_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_decoder_cache,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
@@ -699,7 +768,7 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
checkpoint="{{cookiecutter.checkpoint_identifier}}",
|
||||||
output_type=BaseModelOutputWithCrossAttentions,
|
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -712,6 +781,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -727,6 +798,14 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
- 1 for tokens that are **not masked**,
|
||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -734,19 +813,30 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
else:
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
if input_ids is not None and inputs_embeds is not None:
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
elif input_ids is not None:
|
elif input_ids is not None:
|
||||||
input_shape = input_ids.size()
|
input_shape = input_ids.size()
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
elif inputs_embeds is not None:
|
elif inputs_embeds is not None:
|
||||||
input_shape = inputs_embeds.size()[:-1]
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size, seq_length = input_shape
|
||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
# past_key_values_length
|
||||||
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(input_shape, device=device)
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
@@ -773,7 +863,11 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings(
|
embedding_output = self.embeddings(
|
||||||
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
past_key_values_length=past_key_values_length,
|
||||||
)
|
)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
embedding_output,
|
embedding_output,
|
||||||
@@ -781,6 +875,8 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_extended_attention_mask,
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -790,8 +886,9 @@ class {{cookiecutter.camelcase_modelname}}Model({{cookiecutter.camelcase_modelna
|
|||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sequence_output,) + encoder_outputs[1:]
|
return (sequence_output,) + encoder_outputs[1:]
|
||||||
|
|
||||||
return BaseModelOutputWithCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=sequence_output,
|
last_hidden_state=sequence_output,
|
||||||
|
past_key_values=encoder_outputs.past_key_values,
|
||||||
hidden_states=encoder_outputs.hidden_states,
|
hidden_states=encoder_outputs.hidden_states,
|
||||||
attentions=encoder_outputs.attentions,
|
attentions=encoder_outputs.attentions,
|
||||||
cross_attentions=encoder_outputs.cross_attentions,
|
cross_attentions=encoder_outputs.cross_attentions,
|
||||||
@@ -935,7 +1032,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
encoder_attention_mask=None,
|
encoder_attention_mask=None,
|
||||||
|
past_key_values=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -950,10 +1049,18 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
- 1 for tokens that are **not masked**,
|
||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||||
|
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
|
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
||||||
|
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
||||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
||||||
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
||||||
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``.
|
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``.
|
||||||
|
use_cache (:obj:`bool`, `optional`):
|
||||||
|
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
||||||
|
decoding (see :obj:`past_key_values`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -983,6 +1090,8 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -1006,20 +1115,31 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m
|
|||||||
return CausalLMOutputWithCrossAttentions(
|
return CausalLMOutputWithCrossAttentions(
|
||||||
loss=lm_loss,
|
loss=lm_loss,
|
||||||
logits=prediction_scores,
|
logits=prediction_scores,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
cross_attentions=outputs.cross_attentions,
|
cross_attentions=outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
||||||
input_shape = input_ids.shape
|
input_shape = input_ids.shape
|
||||||
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
attention_mask = input_ids.new_ones(input_shape)
|
||||||
|
|
||||||
|
# cut decoder_input_ids if past is used
|
||||||
|
if past is not None:
|
||||||
|
input_ids = input_ids[:, -1:]
|
||||||
|
|
||||||
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
|
|
||||||
|
def _reorder_cache(self, past, beam_idx):
|
||||||
|
reordered_past = ()
|
||||||
|
for layer_past in past:
|
||||||
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],)
|
||||||
|
return reordered_past
|
||||||
|
|
||||||
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
class {{cookiecutter.camelcase_modelname}}ClassificationHead(nn.Module):
|
||||||
"""Head for sentence-level classification tasks."""
|
"""Head for sentence-level classification tasks."""
|
||||||
|
|
||||||
|
|||||||
@@ -224,6 +224,68 @@ class {{cookiecutter.camelcase_modelname}}ModelTester:
|
|||||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||||
|
|
||||||
|
def create_and_check_decoder_model_past_large_inputs(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
input_ids,
|
||||||
|
token_type_ids,
|
||||||
|
input_mask,
|
||||||
|
sequence_labels,
|
||||||
|
token_labels,
|
||||||
|
choice_labels,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
):
|
||||||
|
config.is_decoder = True
|
||||||
|
config.add_cross_attention = True
|
||||||
|
model = {{cookiecutter.camelcase_modelname}}ForCausalLM(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# first forward pass
|
||||||
|
outputs = model(
|
||||||
|
input_ids,
|
||||||
|
attention_mask=input_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=True,
|
||||||
|
)
|
||||||
|
past_key_values = outputs.past_key_values
|
||||||
|
|
||||||
|
# create hypothetical multiple next token and extent to next_input_ids
|
||||||
|
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||||
|
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
|
||||||
|
|
||||||
|
# append to next input_ids and
|
||||||
|
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||||
|
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
|
||||||
|
|
||||||
|
output_from_no_past = model(
|
||||||
|
next_input_ids,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
output_from_past = model(
|
||||||
|
next_tokens,
|
||||||
|
attention_mask=next_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)["hidden_states"][0]
|
||||||
|
|
||||||
|
# select random slice
|
||||||
|
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||||
|
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
|
||||||
|
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
|
||||||
|
|
||||||
|
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||||
|
|
||||||
|
# test that outputs are equal for slice
|
||||||
|
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
|
||||||
|
|
||||||
def create_and_check_for_question_answering(
|
def create_and_check_for_question_answering(
|
||||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||||
):
|
):
|
||||||
@@ -336,6 +398,10 @@ class {{cookiecutter.camelcase_modelname}}ModelTest(ModelTesterMixin, unittest.T
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_decoder_model_past_with_large_inputs(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||||
|
|
||||||
def test_for_question_answering(self):
|
def test_for_question_answering(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user