adding T5 model
This commit is contained in:
@@ -217,9 +217,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||
encoder_hidden_states = encoder_outputs[
|
||||
0
|
||||
] # output the last layer hidden state
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
else:
|
||||
encoder_outputs = ()
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
|
||||
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -20,11 +20,14 @@ import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import math
|
||||
import sys
|
||||
import itertools
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
|
||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
||||
@@ -119,31 +122,389 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
||||
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
|
||||
####################################################
|
||||
|
||||
class T5Layer(nn.Module):
|
||||
class T5DenseReluDense(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(T5Layer, self).__init__()
|
||||
self.attention = T5Attention(config)
|
||||
self.intermediate = T5Intermediate(config)
|
||||
self.output = T5Output(config)
|
||||
super(T5DenseReluDense, self).__init__()
|
||||
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
|
||||
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
||||
self.dropout = nn.Dropout(config.dropout)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
def forward(self, hidden_states):
|
||||
h = self.wi(hidden_states)
|
||||
h = F.relu(h)
|
||||
h = self.dropout(h)
|
||||
h = self.wo(h)
|
||||
return h
|
||||
|
||||
|
||||
class T5LayerFF(nn.Module):
    """Pre-norm feed-forward sub-layer with a residual connection.

    ``forward`` returns
    ``hidden_states + dropout(DenseReluDense(layer_norm(hidden_states)))``.
    """

    def __init__(self, config):
        super(T5LayerFF, self).__init__()
        self.DenseReluDense = T5DenseReluDense(config)
        # nn.LayerNorm expects the normalized shape as its first argument;
        # the epsilon has to go through the ``eps`` keyword.
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # NOTE(review): T5Attention in this file reads ``config.dropout_rate``
        # while this layer reads ``config.dropout`` -- confirm which attribute
        # T5Config actually defines.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states):
        """Apply layer norm, the dense-ReLU-dense projection and the residual.

        Args:
            hidden_states: tensor fed to the layer norm; its last dimension
                is normalized over ``config.d_model`` features.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        norm_x = self.layer_norm(hidden_states)
        y = self.DenseReluDense(norm_x)
        layer_output = hidden_states + self.dropout(y)
        return layer_output
|
||||
|
||||
|
||||
class T5Attention(nn.Module):
    """Multi-head attention with a learned, bucketed relative position bias.

    The same module serves as self-attention (``kv is None``) and as
    attention over another sequence (``kv`` provided). Queries are not scaled
    by ``sqrt(dim_per_head)``; instead a per-head bias looked up from
    ``relative_attention_bias`` is added to the raw scores.
    """

    # Class-wide counter giving every instance a unique ``layer_id``; the id
    # is the key under which this layer stores its keys/values in ``cache``.
    NEW_ID = itertools.count()

    def __init__(self, config):
        super(T5Attention, self).__init__()
        self.layer_id = next(T5Attention.NEW_ID)

        # ``compute_bias`` needs to know whether attention is bidirectional.
        # T5Block reads the same ``config.is_decoder`` attribute.
        self.is_decoder = config.is_decoder
        self.output_attentions = config.output_attentions
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.dim = config.d_model
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        assert self.dim % self.n_heads == 0

        self.q = nn.Linear(self.dim, self.dim, bias=False)
        self.k = nn.Linear(self.dim, self.dim, bias=False)
        self.v = nn.Linear(self.dim, self.dim, bias=False)
        self.o = nn.Linear(self.dim, self.dim, bias=False)

        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v/o projections.

        Args:
            heads: iterable of head indices (numbered as in the unpruned
                layer). Heads that were pruned earlier are ignored.
        """
        attention_head_size = self.dim // self.n_heads
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_heads, attention_head_size)
        heads = set(heads) - self.pruned_heads
        for head in heads:
            # Shift the index down by the number of already-pruned heads
            # that sat before it.
            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.dim = attention_head_size * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  bidirectional=True,
                                  num_buckets=32,
                                  max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e.
        the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are
        invalid (they are clamped to distance 0).
        We use smaller buckets for small absolute relative_position and larger buckets
        for larger absolute relative_positions. All relative positions >=max_distance
        map to the same bucket. All relative positions <=-max_distance map to the
        same bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on.

        Args:
            relative_position: an integer (long) Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a long Tensor with the same shape as relative_position, containing
            values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Half of the buckets encode the sign of the offset.
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            # Elementwise clamp at zero. ``torch.max(n, 0)`` would instead
            # reduce along dim 0 and return a (values, indices) pair.
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The other half of the buckets are for logarithmically bigger bins
        # in positions up to max_distance. Entries with n == 0 produce -inf
        # here, but they are always "small" and discarded by torch.where.
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(float(max_distance) / max_exact) * (num_buckets - max_exact)
        ).to(torch.long)
        # Elementwise cap at the last bucket (again, not the ``dim`` overload).
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen):
        """ Compute binned relative position bias of shape (1, num_heads, qlen, klen) """
        # Build the position ids on the embedding's device so the lookup
        # also works when the module has been moved off the CPU.
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(qlen, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        rp_bucket = self._relative_position_bucket(relative_position,
                                                   bidirectional=not self.is_decoder,
                                                   num_buckets=self.relative_attention_num_buckets)
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, input, mask, kv=None, position_bias=None, cache=None, head_mask=None):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).

        Args:
            input: tensor of shape (bs, qlen, dim).
            mask: tensor of shape (bs, klen) or (bs, qlen, klen); positions
                equal to 0 are masked out.
            kv: optional tensor of shape (bs, klen, dim) to attend over.
            position_bias: optional precomputed bias added to the scores;
                computed with ``compute_bias`` when None.
            cache: optional dict holding ``'slen'`` and, per ``layer_id``,
                the (k, v) pair from previous steps; updated in place.
            head_mask: optional tensor multiplied into the attention weights.

        Returns:
            ``(context,)`` or ``(context, weights)`` when
            ``self.output_attentions`` is set.
        """
        # NOTE(review): T5Stack in this file builds additive 4-D masks
        # (0 / -10000), whereas this method expects a 0/1 mask of rank 2 or 3
        # -- confirm which convention callers are meant to use.
        bs, qlen, _ = input.size()
        if kv is None:
            klen = qlen if cache is None else cache['slen'] + qlen
        else:
            klen = kv.size(1)
        n_heads = self.n_heads
        dim_per_head = self.dim // n_heads
        mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)

        def shape(x):
            """ projection """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ compute context """
            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)

        q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
        if kv is None:
            k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
            v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
        elif cache is None or self.layer_id not in cache:
            k = v = kv
            k = shape(self.k(k))  # (bs, n_heads, klen, dim_per_head)
            v = shape(self.v(v))  # (bs, n_heads, klen, dim_per_head)

        if cache is not None:
            if self.layer_id in cache:
                if kv is None:
                    # Self-attention: append the new keys/values to the cached ones.
                    k_, v_ = cache[self.layer_id]
                    k = torch.cat([k_, k], dim=2)  # (bs, n_heads, klen, dim_per_head)
                    v = torch.cat([v_, v], dim=2)  # (bs, n_heads, klen, dim_per_head)
                else:
                    # Attention over kv: projections were cached, reuse them.
                    k, v = cache[self.layer_id]
            cache[self.layer_id] = (k, v)

        # No query scaling by sqrt(dim_per_head) in T5
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)

        if position_bias is None:
            position_bias = self.compute_bias(qlen, klen)
        scores += position_bias

        mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
        scores.masked_fill_(mask, -float('inf'))  # (bs, n_heads, qlen, klen)

        weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
        weights = F.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)

        # Mask heads if we want to
        if head_mask is not None:
            weights = weights * head_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
        context = unshape(context)  # (bs, qlen, dim)

        context = self.o(context)

        outputs = (context,)
        if self.output_attentions:
            outputs = outputs + (weights,)
        return outputs
|
||||
|
||||
|
||||
class T5LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sub-layer with a residual connection.

    ``forward`` returns a tuple whose first element is
    ``hidden_states + dropout(SelfAttention(layer_norm(hidden_states))[0])``,
    followed by whatever extra outputs the attention module returned.
    """

    def __init__(self, config):
        super(T5LayerSelfAttention, self).__init__()
        self.SelfAttention = T5Attention(config)
        # nn.LayerNorm expects the normalized shape first; epsilon goes in ``eps``.
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # NOTE(review): T5Attention in this file reads ``config.dropout_rate``
        # while this layer reads ``config.dropout`` -- confirm which attribute
        # T5Config actually defines.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        """Run self-attention on the normalized input and add the residual.

        Args:
            hidden_states: input tensor of the sub-layer.
            attention_mask: mask forwarded to the attention module.
            head_mask: optional per-head mask forwarded to the attention module.

        Returns:
            ``(layer_output,) + attention_extras``.
        """
        norm_x = self.layer_norm(hidden_states)
        # T5Attention.forward names its mask parameter ``mask``; passing
        # ``attention_mask=`` would raise a TypeError.
        attention_output = self.SelfAttention(norm_x,
                                              mask=attention_mask,
                                              head_mask=head_mask)
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
|
||||
|
||||
|
||||
class T5LayerCrossAttention(nn.Module):
    """Pre-norm attention sub-layer over an external sequence, with residual.

    ``forward`` returns a tuple whose first element is
    ``hidden_states + dropout(EncDecAttention(layer_norm(hidden_states), kv=kv)[0])``,
    followed by whatever extra outputs the attention module returned.
    """

    def __init__(self, config):
        super(T5LayerCrossAttention, self).__init__()
        self.EncDecAttention = T5Attention(config)
        # nn.LayerNorm expects the normalized shape first; epsilon goes in ``eps``.
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # NOTE(review): T5Attention in this file reads ``config.dropout_rate``
        # while this layer reads ``config.dropout`` -- confirm which attribute
        # T5Config actually defines.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states, kv, attention_mask=None, head_mask=None):
        """Attend from the normalized input over ``kv`` and add the residual.

        Args:
            hidden_states: input tensor of the sub-layer (the queries).
            kv: sequence attended over (source of keys and values).
            attention_mask: mask forwarded to the attention module.
            head_mask: optional per-head mask forwarded to the attention module.

        Returns:
            ``(layer_output,) + attention_extras``.
        """
        norm_x = self.layer_norm(hidden_states)
        # T5Attention.forward names its mask parameter ``mask``; passing
        # ``attention_mask=`` would raise a TypeError.
        attention_output = self.EncDecAttention(norm_x,
                                                mask=attention_mask,
                                                kv=kv,
                                                head_mask=head_mask)
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y)
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
|
||||
|
||||
|
||||
class T5Block(nn.Module):
    """One transformer block.

    Encoder blocks chain self-attention (``layer_000``) and a feed-forward
    layer (``layer_001``). Decoder blocks insert attention over the encoder
    output in between: self-attention (``layer_000``), cross-attention
    (``layer_001``), feed-forward (``layer_002``).
    """

    def __init__(self, config):
        super(T5Block, self).__init__()
        self.is_decoder = config.is_decoder
        self.layer_000 = T5LayerSelfAttention(config)
        if not self.is_decoder:
            self.layer_001 = T5LayerFF(config)
        else:
            self.layer_001 = T5LayerCrossAttention(config)
            self.layer_002 = T5LayerFF(config)

    def forward(self, hidden_states, attention_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None):
        """Run the block's sub-layers in order.

        Returns:
            ``(hidden_states,)`` followed by the extra attention outputs; in a
            decoder block the cross-attention extras come before the
            self-attention extras.
        """
        self_attn = self.layer_000(hidden_states,
                                   attention_mask=attention_mask,
                                   head_mask=head_mask)
        hidden_states = self_attn[0]
        extras = self_attn[1:]

        if not self.is_decoder:
            # Encoder: feed-forward directly after self-attention.
            hidden_states = self.layer_001(hidden_states)
            return (hidden_states,) + extras

        # Decoder: attend over the encoder output, then feed-forward.
        cross_attn = self.layer_001(hidden_states,
                                    kv=encoder_hidden_states,
                                    attention_mask=encoder_attention_mask,
                                    head_mask=head_mask)
        extras = cross_attn[1:] + extras
        hidden_states = self.layer_002(cross_attn[0])
        return (hidden_states,) + extras
|
||||
|
||||
|
||||
class T5Stack(nn.Module):
    """A stack of ``T5Block`` layers followed by a final layer norm and dropout.

    The stack consumes already-embedded inputs (``hidden_states``), builds the
    additive attention masks for its blocks, and returns the last hidden state
    plus, optionally, all hidden states and all attention weights.
    """

    def __init__(self, config):
        super(T5Stack, self).__init__()
        # ``forward`` needs these; keep them on the module.
        self.config = config
        self.is_decoder = config.is_decoder
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.blocks = nn.ModuleList([T5Block(config) for _ in range(config.num_layers)])
        # nn.LayerNorm expects the normalized shape first; epsilon goes in ``eps``.
        self.final_layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        # NOTE(review): T5Attention in this file reads ``config.dropout_rate``
        # while this module reads ``config.dropout`` -- confirm which attribute
        # T5Config actually defines.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self,
                hidden_states,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                head_mask=None):
        """Run every block of the stack.

        Args:
            hidden_states: embedded inputs; the first two dimensions are read
                as (batch_size, seq_length).
            attention_mask: optional 2D (batch, seq) or 3D (batch, from, to)
                mask with 1 for positions to attend to. Defaults to all ones.
            encoder_hidden_states: optional encoder output forwarded to each
                block.
            encoder_attention_mask: optional 2D or 3D mask over the encoder
                sequence; when None, None is forwarded to the blocks.
            head_mask: optional tensor of shape [num_heads] or
                [num_layers x num_heads]; 1.0 keeps a head.

        Returns:
            Tuple: last-layer hidden state, (all hidden states),
            (all attentions).

        Raises:
            ValueError: if a provided mask is neither 2D nor 3D.
        """
        # NOTE(review): the masks built below are additive (0 / -10000) and
        # 4-D, whereas T5Attention.forward in this file expects a 0/1 mask of
        # rank 2 or 3 -- confirm which convention is intended.
        batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
        device = hidden_states.device
        param_dtype = next(self.parameters()).dtype  # fp16 compatibility

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, seq_length, device=device)

        if attention_mask.dim() == 3:
            # A full [batch_size, from_seq_length, to_seq_length] mask was
            # provided: only make it broadcastable to all heads.
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to
            #   [batch_size, num_heads, seq_length, seq_length]
            if self.is_decoder:
                seq_ids = torch.arange(seq_length, device=device)
                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                # The comparison yields a boolean/byte tensor; match the
                # padding mask's dtype before multiplying.
                causal_mask = causal_mask.to(attention_mask.dtype)
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError("attention_mask must be 2D or 3D, got %d dimensions" % attention_mask.dim())

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_attention_mask is None:
            # Encoder stacks (and decoders called without a mask) have none.
            encoder_extended_attention_mask = None
        else:
            if encoder_attention_mask.dim() == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            elif encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            else:
                raise ValueError(
                    "encoder_attention_mask must be 2D or 3D, got %d dimensions" % encoder_attention_mask.dim())
            encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=param_dtype)
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_layers x num_heads]
        # and head_mask is converted to shape [num_layers x batch x num_heads x seq_length x seq_length]
        num_layers = len(self.blocks)
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(num_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=param_dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * num_layers

        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.blocks):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states,
                                         attention_mask=extended_attention_mask,
                                         encoder_hidden_states=encoder_hidden_states,
                                         encoder_attention_mask=encoder_extended_attention_mask,
                                         head_mask=head_mask[i])
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.final_layer_norm(hidden_states)
        # Keep the dropout result: it is what the stack returns.
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class T5PreTrainedModel(PreTrainedEncoderDecoder):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for downloading and loading pretrained models.
|
||||
"""
|
||||
config_class = T5Config
|
||||
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_t5
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
@@ -238,19 +599,23 @@ class T5Model(T5PreTrainedModel):
|
||||
"""
|
||||
def __init__(self, config):
    """Build the shared token embedding and the encoder and decoder stacks.

    The encoder and the decoder each receive their own deep copy of
    ``config``; the decoder copy has ``is_decoder`` set to True.
    """
    # Local import: ``copy`` is not among the imports visible in this file.
    import copy

    super(T5Model, self).__init__(config)
    # The torch class is ``nn.Embedding`` (there is no ``nn.Embeddings``).
    self.shared = nn.Embedding(config.vocab_size, config.d_model)

    encoder_config = copy.deepcopy(config)
    self.encoder = T5Stack(encoder_config)

    decoder_config = copy.deepcopy(config)
    decoder_config.is_decoder = True
    self.decoder = T5Stack(decoder_config)

    self.init_weights()
|
||||
|
||||
@property
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
return self.shared
|
||||
|
||||
def set_input_embeddings(self, new_embeddings):
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
self.shared = new_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
@@ -260,50 +625,36 @@ class T5Model(T5PreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
    """Encode (unless encoder states are supplied) and decode.

    Keyword arguments come in 3 flavors: encoder-specific (prefixed by
    ``encoder_``), decoder-specific (prefixed by ``decoder_``) and those
    that apply to the model as a whole. The specific kwargs override the
    common ones in case of conflict; prefixes are stripped before the
    kwargs are handed to the corresponding stack.

    Args:
        encoder_input_ids: token ids embedded with ``self.shared`` and fed
            to the encoder. Ignored when ``encoder_hidden_states`` is
            passed as a keyword argument.
        decoder_input_ids: token ids embedded with ``self.shared`` and fed
            to the decoder.
        **kwargs: see above. ``encoder_hidden_states`` skips the encoder
            pass; ``encoder_attention_mask`` is also forwarded to the
            decoder as its ``encoder_attention_mask``.

    Returns:
        ``decoder_outputs + encoder_outputs`` (the encoder part is an empty
        tuple when the encoder pass was skipped).
    """
    kwargs_common = {k: v for k, v in kwargs.items()
                     if not k.startswith(("encoder_", "decoder_"))}
    kwargs_encoder = kwargs_common.copy()
    kwargs_decoder = kwargs_common.copy()
    kwargs_encoder.update({k[len("encoder_"):]: v for k, v in kwargs.items() if k.startswith("encoder_")})
    kwargs_decoder.update({k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")})

    # Encode if needed (training, first prediction pass)
    encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
    if encoder_hidden_states is None:
        # The ids arrive as a named parameter, so they never show up in
        # ``kwargs``; popping "input_ids" from it would raise KeyError.
        hidden_states = self.shared(encoder_input_ids)  # Convert inputs in embeddings
        encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
        encoder_hidden_states = encoder_outputs[0]
    else:
        encoder_outputs = ()

    # Decode
    hidden_states = self.shared(decoder_input_ids)  # Convert inputs in embeddings
    kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
    kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
    decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

    return decoder_outputs + encoder_outputs
|
||||
|
||||
|
||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
|
||||
@@ -342,7 +693,7 @@ class T5WithLMHead(T5PreTrainedModel):
|
||||
super(T5ForMaskedLM, self).__init__(config)
|
||||
|
||||
self.transformer = T5Model(config)
|
||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
||||
self.lm_head = nn.Linear(config.d_model, config.vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user