adding T5 model
This commit is contained in:
@@ -217,9 +217,7 @@ class PreTrainedEncoderDecoder(nn.Module):
|
|||||||
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||||
if encoder_hidden_states is None:
|
if encoder_hidden_states is None:
|
||||||
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
|
||||||
encoder_hidden_states = encoder_outputs[
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
0
|
|
||||||
] # output the last layer hidden state
|
|
||||||
else:
|
else:
|
||||||
encoder_outputs = ()
|
encoder_outputs = ()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
|
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -20,11 +20,14 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
import sys
|
import sys
|
||||||
|
import itertools
|
||||||
from io import open
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.nn import CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
from .modeling_utils import PreTrainedModel, prune_linear_layer
|
||||||
@@ -119,31 +122,389 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
|
|||||||
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
|
# - PreTrainedModel for the models (it-self a sub-class of torch.nn.Module)
|
||||||
####################################################
|
####################################################
|
||||||
|
|
||||||
class T5Layer(nn.Module):
|
class T5DenseReluDense(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(T5Layer, self).__init__()
|
super(T5DenseReluDense, self).__init__()
|
||||||
self.attention = T5Attention(config)
|
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
|
||||||
self.intermediate = T5Intermediate(config)
|
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
|
||||||
self.output = T5Output(config)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
def forward(self, hidden_states):
|
||||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
h = self.wi(hidden_states)
|
||||||
attention_output = attention_outputs[0]
|
h = F.relu(h)
|
||||||
intermediate_output = self.intermediate(attention_output)
|
h = self.dropout(h)
|
||||||
layer_output = self.output(intermediate_output, attention_output)
|
h = self.wo(h)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class T5LayerFF(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(T5LayerFF, self).__init__()
|
||||||
|
self.DenseReluDense = T5DenseReluDense(config)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
|
||||||
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
norm_x = self.layer_norm(hidden_states)
|
||||||
|
y = self.DenseReluDense(norm_x)
|
||||||
|
layer_output = hidden_states + self.dropout(y)
|
||||||
|
return layer_output
|
||||||
|
|
||||||
|
|
||||||
|
class T5Attention(nn.Module):
|
||||||
|
NEW_ID = itertools.count()
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super(T5Attention, self).__init__()
|
||||||
|
self.layer_id = next(T5Attention.NEW_ID)
|
||||||
|
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||||
|
self.dim = config.d_model
|
||||||
|
self.n_heads = config.num_heads
|
||||||
|
self.dropout = config.dropout_rate
|
||||||
|
assert self.dim % self.n_heads == 0
|
||||||
|
|
||||||
|
self.q = nn.Linear(self.dim, self.dim, bias=False)
|
||||||
|
self.k = nn.Linear(self.dim, self.dim, bias=False)
|
||||||
|
self.v = nn.Linear(self.dim, self.dim, bias=False)
|
||||||
|
self.o = nn.Linear(self.dim, self.dim, bias=False)
|
||||||
|
|
||||||
|
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
|
||||||
|
self.pruned_heads = set()
|
||||||
|
|
||||||
|
def prune_heads(self, heads):
|
||||||
|
attention_head_size = self.dim // self.n_heads
|
||||||
|
if len(heads) == 0:
|
||||||
|
return
|
||||||
|
mask = torch.ones(self.n_heads, attention_head_size)
|
||||||
|
heads = set(heads) - self.pruned_heads
|
||||||
|
for head in heads:
|
||||||
|
head -= sum(1 if h < head else 0 for h in self.pruned_heads)
|
||||||
|
mask[head] = 0
|
||||||
|
mask = mask.view(-1).contiguous().eq(1)
|
||||||
|
index = torch.arange(len(mask))[mask].long()
|
||||||
|
# Prune linear layers
|
||||||
|
self.q = prune_linear_layer(self.q, index)
|
||||||
|
self.k = prune_linear_layer(self.k, index)
|
||||||
|
self.v = prune_linear_layer(self.v, index)
|
||||||
|
self.o = prune_linear_layer(self.o, index, dim=1)
|
||||||
|
# Update hyper params
|
||||||
|
self.n_heads = self.n_heads - len(heads)
|
||||||
|
self.dim = attention_head_size * self.n_heads
|
||||||
|
self.pruned_heads = self.pruned_heads.union(heads)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _relative_position_bucket(relative_position,
|
||||||
|
bidirectional=True,
|
||||||
|
num_buckets=32,
|
||||||
|
max_distance=128):
|
||||||
|
"""
|
||||||
|
Adapted from Mesh Tensorflow:
|
||||||
|
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
|
||||||
|
|
||||||
|
Translate relative position to a bucket number for relative attention.
|
||||||
|
The relative position is defined as memory_position - query_position, i.e.
|
||||||
|
the distance in tokens from the attending position to the attended-to
|
||||||
|
position. If bidirectional=False, then positive relative positions are
|
||||||
|
invalid.
|
||||||
|
We use smaller buckets for small absolute relative_position and larger buckets
|
||||||
|
for larger absolute relative_positions. All relative positions >=max_distance
|
||||||
|
map to the same bucket. All relative positions <=-max_distance map to the
|
||||||
|
same bucket. This should allow for more graceful generalization to longer
|
||||||
|
sequences than the model has been trained on.
|
||||||
|
Args:
|
||||||
|
relative_position: an int32 Tensor
|
||||||
|
bidirectional: a boolean - whether the attention is bidirectional
|
||||||
|
num_buckets: an integer
|
||||||
|
max_distance: an integer
|
||||||
|
Returns:
|
||||||
|
a Tensor with the same shape as relative_position, containing int32
|
||||||
|
values in the range [0, num_buckets)
|
||||||
|
"""
|
||||||
|
ret = 0
|
||||||
|
n = -relative_position
|
||||||
|
if bidirectional:
|
||||||
|
num_buckets //= 2
|
||||||
|
ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets
|
||||||
|
n = torch.abs(n)
|
||||||
|
else:
|
||||||
|
n = torch.max(n, 0)
|
||||||
|
# now n is in the range [0, inf)
|
||||||
|
|
||||||
|
# half of the buckets are for exact increments in positions
|
||||||
|
max_exact = num_buckets // 2
|
||||||
|
is_small = (n < max_exact)
|
||||||
|
|
||||||
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
val_if_large = max_exact + (
|
||||||
|
torch.log(n.float() / max_exact)
|
||||||
|
/ math.log(max_distance / max_exact) * (num_buckets - max_exact)).to(torch.long)
|
||||||
|
val_if_large = torch.min(val_if_large, num_buckets - 1)
|
||||||
|
|
||||||
|
ret += torch.where(is_small, n, val_if_large)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def compute_bias(self, qlen, klen):
|
||||||
|
""" Compute binned relative position bias """
|
||||||
|
context_position = torch.arange(qlen, dtype=torch.long)[:, None]
|
||||||
|
memory_position = torch.arange(klen, dtype=torch.long)[None, :]
|
||||||
|
relative_position = memory_position - context_position # shape (qlen, klen)
|
||||||
|
rp_bucket = self._relative_position_bucket(relative_position,
|
||||||
|
bidirectional=not self.is_decoder,
|
||||||
|
num_buckets=self.relative_attention_num_buckets)
|
||||||
|
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
|
||||||
|
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen)
|
||||||
|
return values
|
||||||
|
|
||||||
|
def forward(self, input, mask, kv=None, position_bias=None, cache=None, head_mask=None):
|
||||||
|
"""
|
||||||
|
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
||||||
|
"""
|
||||||
|
# Input is (bs, qlen, dim)
|
||||||
|
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
||||||
|
bs, qlen, dim = input.size()
|
||||||
|
if kv is None:
|
||||||
|
klen = qlen if cache is None else cache['slen'] + qlen
|
||||||
|
else:
|
||||||
|
klen = kv.size(1)
|
||||||
|
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||||
|
n_heads = self.n_heads
|
||||||
|
dim_per_head = self.dim // n_heads
|
||||||
|
mask_reshape = (bs, 1, qlen, klen) if mask.dim() == 3 else (bs, 1, 1, klen)
|
||||||
|
|
||||||
|
def shape(x):
|
||||||
|
""" projection """
|
||||||
|
return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
|
||||||
|
|
||||||
|
def unshape(x):
|
||||||
|
""" compute context """
|
||||||
|
return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
|
||||||
|
|
||||||
|
q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
|
||||||
|
if kv is None:
|
||||||
|
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
|
||||||
|
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
|
||||||
|
elif cache is None or self.layer_id not in cache:
|
||||||
|
k = v = kv
|
||||||
|
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
|
||||||
|
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
if self.layer_id in cache:
|
||||||
|
if kv is None:
|
||||||
|
k_, v_ = cache[self.layer_id]
|
||||||
|
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
|
||||||
|
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
|
||||||
|
else:
|
||||||
|
k, v = cache[self.layer_id]
|
||||||
|
cache[self.layer_id] = (k, v)
|
||||||
|
|
||||||
|
# q = q / math.sqrt(dim_per_head) # No scaling in T5
|
||||||
|
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, qlen, klen)
|
||||||
|
|
||||||
|
if position_bias is None:
|
||||||
|
position_bias = self.compute_bias(qlen, klen)
|
||||||
|
scores += position_bias
|
||||||
|
|
||||||
|
mask = (mask == 0).view(mask_reshape).expand_as(scores) # (bs, n_heads, qlen, klen)
|
||||||
|
scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
|
||||||
|
|
||||||
|
weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen)
|
||||||
|
weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen)
|
||||||
|
|
||||||
|
# Mask heads if we want to
|
||||||
|
if head_mask is not None:
|
||||||
|
weights = weights * head_mask
|
||||||
|
|
||||||
|
context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head)
|
||||||
|
context = unshape(context) # (bs, qlen, dim)
|
||||||
|
|
||||||
|
context = self.o(context)
|
||||||
|
|
||||||
|
outputs = (context,)
|
||||||
|
if self.output_attentions:
|
||||||
|
outputs = outputs + (weights,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class T5LayerSelfAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(T5LayerSelfAttention, self).__init__()
|
||||||
|
self.SelfAttention = T5Attention(config)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
|
||||||
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
class T5PreTrainedModel(PreTrainedModel):
|
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||||
|
norm_x = self.layer_norm(hidden_states)
|
||||||
|
attention_output = self.SelfAttention(norm_x,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask)
|
||||||
|
y = attention_output[0]
|
||||||
|
layer_output = hidden_states + self.dropout(y)
|
||||||
|
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class T5LayerCrossAttention(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(T5LayerCrossAttention, self).__init__()
|
||||||
|
self.EncDecAttention = T5Attention(config)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
|
||||||
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, kv, attention_mask=None, head_mask=None):
|
||||||
|
norm_x = self.layer_norm(hidden_states)
|
||||||
|
attention_output = self.EncDecAttention(norm_x,
|
||||||
|
kv=kv,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask)
|
||||||
|
y = attention_output[0]
|
||||||
|
layer_output = hidden_states + self.dropout(y)
|
||||||
|
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class T5Block(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(T5Block, self).__init__()
|
||||||
|
self.is_decoder = config.is_decoder
|
||||||
|
self.layer_000 = T5LayerSelfAttention(config)
|
||||||
|
if self.is_decoder:
|
||||||
|
self.layer_001 = T5LayerCrossAttention(config)
|
||||||
|
self.layer_002 = T5LayerFF(config)
|
||||||
|
else:
|
||||||
|
self.layer_001 = T5LayerFF(config)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attention_mask=None,
|
||||||
|
encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None):
|
||||||
|
self_attention_outputs = self.layer_000(hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask)
|
||||||
|
hidden_states = self_attention_outputs[0]
|
||||||
|
outputs = self_attention_outputs[1:]
|
||||||
|
|
||||||
|
if self.is_decoder:
|
||||||
|
cross_attention_outputs = self.layer_001(hidden_states,
|
||||||
|
kv=encoder_hidden_states,
|
||||||
|
attention_mask=encoder_attention_mask,
|
||||||
|
head_mask=head_mask)
|
||||||
|
hidden_states = cross_attention_outputs[0]
|
||||||
|
outputs = cross_attention_outputs[1:] + outputs
|
||||||
|
hidden_states = self.layer_002(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = self.layer_001(hidden_states)
|
||||||
|
|
||||||
|
outputs = (hidden_states,) + outputs # add attentions if we output them
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class T5Stack(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super(T5Stack, self).__init__()
|
||||||
|
self.blocks = nn.ModuleList([T5Block(config) for _ in range(config.num_layers)])
|
||||||
|
self.final_layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
|
||||||
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
encoder_attention_mask=None,
|
||||||
|
head_mask=None):
|
||||||
|
|
||||||
|
if attention_mask is None:
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
|
if attention_mask.dim() == 3:
|
||||||
|
extended_attention_mask = attention_mask[:, None, :, :]
|
||||||
|
|
||||||
|
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||||
|
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
||||||
|
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if attention_mask.dim() == 2:
|
||||||
|
if self.config.is_decoder:
|
||||||
|
batch_size, seq_length = input_ids.size()
|
||||||
|
seq_ids = torch.arange(seq_length, device=input_ids.device)
|
||||||
|
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
||||||
|
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
||||||
|
else:
|
||||||
|
extended_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and -10000.0 for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||||
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
|
|
||||||
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if encoder_attention_mask.dim() == 3:
|
||||||
|
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
||||||
|
if encoder_attention_mask.dim() == 2:
|
||||||
|
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||||
|
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||||
|
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||||
|
if head_mask is not None:
|
||||||
|
if head_mask.dim() == 1:
|
||||||
|
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||||
|
elif head_mask.dim() == 2:
|
||||||
|
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||||
|
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||||
|
else:
|
||||||
|
head_mask = [None] * self.config.num_hidden_layers
|
||||||
|
|
||||||
|
all_hidden_states = ()
|
||||||
|
all_attentions = ()
|
||||||
|
position_bias = None
|
||||||
|
for i, layer_module in enumerate(self.layer):
|
||||||
|
if self.output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
layer_outputs = layer_module(hidden_states,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_extended_attention_mask,
|
||||||
|
head_mask=head_mask[i])
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if self.output_attentions:
|
||||||
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
layer_output = self.dropout(hidden_states)
|
||||||
|
|
||||||
|
# Add last layer
|
||||||
|
if self.output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
if self.output_hidden_states:
|
||||||
|
outputs = outputs + (all_hidden_states,)
|
||||||
|
if self.output_attentions:
|
||||||
|
outputs = outputs + (all_attentions,)
|
||||||
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|
||||||
|
class T5PreTrainedModel(PreTrainedEncoderDecoder):
|
||||||
""" An abstract class to handle weights initialization and
|
""" An abstract class to handle weights initialization and
|
||||||
a simple interface for dowloading and loading pretrained models.
|
a simple interface for dowloading and loading pretrained models.
|
||||||
"""
|
"""
|
||||||
config_class = T5Config
|
config_class = T5Config
|
||||||
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
load_tf_weights = load_tf_weights_in_t5
|
load_tf_weights = load_tf_weights_in_t5
|
||||||
base_model_prefix = "transformer"
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -238,19 +599,23 @@ class T5Model(T5PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(T5Model, self).__init__(config)
|
super(T5Model, self).__init__(config)
|
||||||
|
self.shared = nn.Embeddings(config.vocab_size, config.d_model)
|
||||||
|
|
||||||
self.embeddings = T5Embeddings(config)
|
encoder_config = copy.deepcopy(config)
|
||||||
self.encoder = T5Encoder(config)
|
self.encoder = T5Stack(encoder_config)
|
||||||
self.pooler = T5Pooler(config)
|
|
||||||
|
decoder_config = copy.deepcopy(config)
|
||||||
|
decoder_config.is_decoder = True
|
||||||
|
self.decoder = T5Stack(decoder_config)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.embeddings.word_embeddings
|
return self.shared
|
||||||
|
|
||||||
def set_input_embeddings(self, new_embeddings):
|
def set_input_embeddings(self, new_embeddings):
|
||||||
self.embeddings.word_embeddings = new_embeddings
|
self.shared = new_embeddings
|
||||||
|
|
||||||
def _prune_heads(self, heads_to_prune):
|
def _prune_heads(self, heads_to_prune):
|
||||||
""" Prunes heads of the model.
|
""" Prunes heads of the model.
|
||||||
@@ -260,50 +625,36 @@ class T5Model(T5PreTrainedModel):
|
|||||||
for layer, heads in heads_to_prune.items():
|
for layer, heads in heads_to_prune.items():
|
||||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||||
|
|
||||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
|
||||||
if attention_mask is None:
|
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
|
||||||
attention_mask = torch.ones_like(input_ids)
|
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
|
||||||
if token_type_ids is None:
|
# that apply to the model as whole.
|
||||||
token_type_ids = torch.zeros_like(input_ids)
|
# We let the specific kwargs override the common ones in case of conflict.
|
||||||
|
kwargs_common = dict((k, v) for k, v in kwargs.items()
|
||||||
|
if not k.startswith("encoder_") and not k.startswith("decoder_"))
|
||||||
|
kwargs_decoder = kwargs_common.copy()
|
||||||
|
kwargs_encoder = kwargs_common.copy()
|
||||||
|
kwargs_encoder.update(dict((k[len("encoder_") :], v) for k, v in kwargs.items() if k.startswith("encoder_")))
|
||||||
|
kwargs_decoder.update(dict((k[len("decoder_") :], v) for k, v in kwargs.items() if k.startswith("decoder_")))
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
# Encode if needed (training, first prediction pass)
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
if encoder_hidden_states is None:
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
encoder_inputs_ids = kwargs_encoder.pop("input_ids")
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
hidden_states = self.shared(encoder_inputs_ids) # Convert inputs in embeddings
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
|
||||||
|
encoder_hidden_states = encoder_outputs[0]
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
||||||
|
|
||||||
# Prepare head mask if needed
|
|
||||||
# 1.0 in head_mask indicate we keep the head
|
|
||||||
# attention_probs has shape bsz x n_heads x N x N
|
|
||||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
|
||||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
|
||||||
if head_mask is not None:
|
|
||||||
if head_mask.dim() == 1:
|
|
||||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
|
||||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
|
||||||
elif head_mask.dim() == 2:
|
|
||||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
|
||||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.config.num_hidden_layers
|
encoder_outputs = ()
|
||||||
|
|
||||||
##################################
|
# Decode
|
||||||
# Replace this with your model code
|
decoder_inputs_ids = kwargs_decoder.pop("input_ids")
|
||||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
hidden_states = self.shared(decoder_inputs_ids) # Convert inputs in embeddings
|
||||||
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
|
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
|
||||||
sequence_output = encoder_outputs[0]
|
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
|
||||||
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
|
||||||
|
|
||||||
return outputs # sequence_output, (hidden_states), (attentions)
|
return decoder_outputs + encoder_outputs
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
|
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
|
||||||
@@ -342,7 +693,7 @@ class T5WithLMHead(T5PreTrainedModel):
|
|||||||
super(T5ForMaskedLM, self).__init__(config)
|
super(T5ForMaskedLM, self).__init__(config)
|
||||||
|
|
||||||
self.transformer = T5Model(config)
|
self.transformer = T5Model(config)
|
||||||
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
|
self.lm_head = nn.Linear(config.d_model, config.vocab_size)
|
||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user