Black 20 release

This commit is contained in:
Lysandre
2020-08-26 17:20:22 +02:00
parent e78c110338
commit a75c64d80c
191 changed files with 4807 additions and 3503 deletions

View File

@@ -105,10 +105,17 @@ class BertAbs(BertAbsPreTrainedModel):
p.data.zero_()
def forward(
self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
self,
encoder_input_ids,
decoder_input_ids,
token_type_ids,
encoder_attention_mask,
decoder_attention_mask,
):
encoder_output = self.bert(
input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
input_ids=encoder_input_ids,
token_type_ids=token_type_ids,
attention_mask=encoder_attention_mask,
)
encoder_hidden_states = encoder_output[0]
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
@@ -117,8 +124,7 @@ class BertAbs(BertAbsPreTrainedModel):
class Bert(nn.Module):
""" This class is not really necessary and should probably disappear.
"""
"""This class is not really necessary and should probably disappear."""
def __init__(self):
super().__init__()
@@ -307,7 +313,14 @@ class TransformerDecoderLayer(nn.Module):
self.register_buffer("mask", mask)
def forward(
self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
self,
inputs,
memory_bank,
src_pad_mask,
tgt_pad_mask,
previous_input=None,
layer_cache=None,
step=None,
):
"""
Args:
@@ -331,13 +344,25 @@ class TransformerDecoderLayer(nn.Module):
all_input = torch.cat((previous_input, input_norm), dim=1)
dec_mask = None
query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)
query = self.self_attn(
all_input,
all_input,
input_norm,
mask=dec_mask,
layer_cache=layer_cache,
type="self",
)
query = self.drop(query) + inputs
query_norm = self.layer_norm_2(query)
mid = self.context_attn(
memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
memory_bank,
memory_bank,
query_norm,
mask=src_pad_mask,
layer_cache=layer_cache,
type="context",
)
output = self.feed_forward(self.drop(mid) + query)
@@ -422,7 +447,14 @@ class MultiHeadedAttention(nn.Module):
self.final_linear = nn.Linear(model_dim, model_dim)
def forward(
self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
self,
key,
value,
query,
mask=None,
layer_cache=None,
type=None,
predefined_graph_1=None,
):
"""
Compute the context vector and the attention vectors.
@@ -628,7 +660,7 @@ def gelu(x):
class PositionwiseFeedForward(nn.Module):
""" A two-layer Feed-Forward-Network with residual layer norm.
"""A two-layer Feed-Forward-Network with residual layer norm.
Args:
d_model (int): the size of input for the first-layer of the FFN.
@@ -770,8 +802,7 @@ class Translator(object):
self.max_length = args.max_length
def translate(self, batch, step, attn_debug=False):
""" Generates summaries from one batch of data.
"""
"""Generates summaries from one batch of data."""
self.model.eval()
with torch.no_grad():
batch_data = self.translate_batch(batch)
@@ -798,8 +829,7 @@ class Translator(object):
# Where the beam search lives
# I have no idea why it is being called from the method above
def _fast_translate_batch(self, batch, max_length, min_length=0):
""" Beam Search using the encoder inputs contained in `batch`.
"""
"""Beam Search using the encoder inputs contained in `batch`."""
# The batch object is funny
# Instead of just looking at the size of the arguments we encapsulate
@@ -981,7 +1011,7 @@ def tile(x, count, dim=0):
class BertSumOptimizer(object):
""" Specific optimizer for BertSum.
"""Specific optimizer for BertSum.
As described in [1], the authors fine-tune BertSum for abstractive
summarization using two Adam Optimizers with different warm-up steps and
@@ -999,10 +1029,16 @@ class BertSumOptimizer(object):
self.optimizers = {
"encoder": torch.optim.Adam(
model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
model.encoder.parameters(),
lr=lr["encoder"],
betas=(beta_1, beta_2),
eps=eps,
),
"decoder": torch.optim.Adam(
model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
model.decoder.parameters(),
lr=lr["decoder"],
betas=(beta_1, beta_2),
eps=eps,
),
}