[All models] Extend config.output_attentions with output_attentions function arguments (#4538)

* DOC: Replace instances of ``config.output_attentions`` with function argument ``output_attentions``

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attentions`

* Fix further regressions in tests relating to `output_attentions`

Ensure proper propagation of `output_attentions` as a function parameter
to all model subclasses

* Fix more regressions in `test_output_attentions`

* Fix issues with BertEncoder

* Rename related variables to `output_attentions`

* fix pytorch tests

* fix bert and gpt2 tf

* Fix most TF tests for `test_output_attentions`

* Fix linter errors and more TF tests

* fix conflicts

* DOC: Apply Black Formatting

* Fix errors where output_attentions was undefined

* Remove output_attentions in classes per review

* Fix regressions on tests having `output_attentions`

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix pytorch tests

* fix conflicts

* fix conflicts

* Fix linter errors and more TF tests

* fix tf tests

* make style

* fix isort

* improve output_attentions

* improve tensorflow

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Bharat Raghunathan
2020-06-10 03:09:06 +05:30
committed by GitHub
parent f90bc44d9a
commit 6e603cb789
38 changed files with 1108 additions and 549 deletions

View File

@@ -83,9 +83,8 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model_size, num_heads, output_attentions=False):
def __init__(self, d_model_size, num_heads):
super().__init__()
self.output_attentions = output_attentions
self.num_heads = num_heads
self.d_model_size = d_model_size
@@ -101,7 +100,18 @@ class MultiHeadAttention(torch.nn.Module):
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
return x.permute([0, 2, 1, 3])
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
def forward(
self,
v,
k,
q,
mask,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
batch_size = q.shape[0]
q = self.Wq(q)
@@ -128,7 +138,7 @@ class MultiHeadAttention(torch.nn.Module):
output = self.dense(original_size_attention)
outputs = (output, present)
if self.output_attentions:
if output_attentions:
outputs = outputs + (attn,)
return outputs
@@ -138,10 +148,10 @@ def point_wise_feed_forward_network(d_model_size, dff):
class EncoderLayer(torch.nn.Module):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
def __init__(self, d_model_size, num_heads, dff, rate=0.1):
super().__init__()
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
@@ -150,7 +160,9 @@ class EncoderLayer(torch.nn.Module):
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
def forward(
self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
):
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
normed,
@@ -161,6 +173,7 @@ class EncoderLayer(torch.nn.Module):
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output)
@@ -264,7 +277,6 @@ class CTRLModel(CTRLPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
@@ -275,10 +287,7 @@ class CTRLModel(CTRLPreTrainedModel):
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList(
[
EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop, config.output_attentions)
for _ in range(config.n_layer)
]
[EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)]
)
self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -308,6 +317,7 @@ class CTRLModel(CTRLPreTrainedModel):
head_mask=None,
inputs_embeds=None,
use_cache=True,
output_attentions=None,
):
r"""
Return:
@@ -322,7 +332,7 @@ class CTRLModel(CTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -343,6 +353,7 @@ class CTRLModel(CTRLPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -424,12 +435,13 @@ class CTRLModel(CTRLPreTrainedModel):
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states, present = outputs[:2]
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
if output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states)
@@ -442,7 +454,7 @@ class CTRLModel(CTRLPreTrainedModel):
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
@@ -485,6 +497,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
inputs_embeds=None,
labels=None,
use_cache=True,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -508,7 +521,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
@@ -537,6 +550,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = transformer_outputs[0]