Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it leaves room for explicit variable
names, which make the code easier to understand.
This commit is contained in:
@@ -56,40 +56,22 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
||||
if load_bert_pretrained_extractive:
|
||||
self.bert.model.load_state_dict(
|
||||
dict(
|
||||
[
|
||||
(n[11:], p)
|
||||
for n, p in bert_extractive_checkpoint.items()
|
||||
if n.startswith("bert.model")
|
||||
]
|
||||
),
|
||||
dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
|
||||
strict=True,
|
||||
)
|
||||
|
||||
self.vocab_size = self.bert.model.config.vocab_size
|
||||
|
||||
if args.max_pos > 512:
|
||||
my_pos_embeddings = nn.Embedding(
|
||||
args.max_pos, self.bert.model.config.hidden_size
|
||||
)
|
||||
my_pos_embeddings.weight.data[
|
||||
:512
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
my_pos_embeddings.weight.data[
|
||||
512:
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||
my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
|
||||
my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||
None, :
|
||||
].repeat(
|
||||
args.max_pos - 512, 1
|
||||
)
|
||||
].repeat(args.max_pos - 512, 1)
|
||||
self.bert.model.embeddings.position_embeddings = my_pos_embeddings
|
||||
tgt_embeddings = nn.Embedding(
|
||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
||||
)
|
||||
tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
|
||||
|
||||
tgt_embeddings.weight = copy.deepcopy(
|
||||
self.bert.model.embeddings.word_embeddings.weight
|
||||
)
|
||||
tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
|
||||
|
||||
self.decoder = TransformerDecoder(
|
||||
self.args.dec_layers,
|
||||
@@ -102,9 +84,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
)
|
||||
|
||||
gen_func = nn.LogSoftmax(dim=-1)
|
||||
self.generator = nn.Sequential(
|
||||
nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func
|
||||
)
|
||||
self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
|
||||
self.generator[0].weight = self.decoder.embeddings.weight
|
||||
|
||||
load_from_checkpoints = False if checkpoint is None else True
|
||||
@@ -127,25 +107,14 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
p.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_input_ids,
|
||||
decoder_input_ids,
|
||||
token_type_ids,
|
||||
encoder_attention_mask,
|
||||
decoder_attention_mask,
|
||||
self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
|
||||
):
|
||||
encoder_output = self.bert(
|
||||
input_ids=encoder_input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=encoder_attention_mask,
|
||||
input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
|
||||
)
|
||||
encoder_hidden_states = encoder_output[0]
|
||||
dec_state = self.decoder.init_decoder_state(
|
||||
encoder_input_ids, encoder_hidden_states
|
||||
)
|
||||
decoder_outputs, _ = self.decoder(
|
||||
decoder_input_ids[:, :-1], encoder_hidden_states, dec_state
|
||||
)
|
||||
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
|
||||
decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
@@ -162,10 +131,7 @@ class Bert(nn.Module):
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
encoder_outputs, _ = self.model(
|
||||
input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs
|
||||
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
|
||||
)
|
||||
return encoder_outputs
|
||||
|
||||
@@ -196,10 +162,7 @@ class TransformerDecoder(nn.Module):
|
||||
|
||||
# Build TransformerDecoder.
|
||||
self.transformer_layers = nn.ModuleList(
|
||||
[
|
||||
TransformerDecoderLayer(d_model, heads, d_ff, dropout)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
[TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
|
||||
)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
@@ -236,20 +199,14 @@ class TransformerDecoder(nn.Module):
|
||||
# Decoder padding mask
|
||||
tgt_words = tgt
|
||||
tgt_batch, tgt_len = tgt_words.size()
|
||||
tgt_pad_mask = (
|
||||
tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||
)
|
||||
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||
|
||||
# Encoder padding mask
|
||||
if memory_mask is not None:
|
||||
src_len = memory_mask.size(-1)
|
||||
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
|
||||
else:
|
||||
src_pad_mask = (
|
||||
src_words.data.eq(padding_idx)
|
||||
.unsqueeze(1)
|
||||
.expand(src_batch, tgt_len, src_len)
|
||||
)
|
||||
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)
|
||||
|
||||
# Pass through the embeddings
|
||||
emb = self.embeddings(input_ids)
|
||||
@@ -271,9 +228,7 @@ class TransformerDecoder(nn.Module):
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=prev_layer_input,
|
||||
layer_cache=state.cache["layer_{}".format(i)]
|
||||
if state.cache is not None
|
||||
else None,
|
||||
layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
|
||||
step=step,
|
||||
)
|
||||
if state.cache is None:
|
||||
@@ -303,9 +258,7 @@ class PositionalEncoding(nn.Module):
|
||||
def __init__(self, dropout, dim, max_len=5000):
|
||||
pe = torch.zeros(max_len, dim)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||
)
|
||||
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
|
||||
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
||||
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
@@ -356,14 +309,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
memory_bank,
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=None,
|
||||
layer_cache=None,
|
||||
step=None,
|
||||
self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -380,34 +326,20 @@ class TransformerDecoderLayer(nn.Module):
|
||||
* all_input `[batch_size x current_step x model_dim]`
|
||||
|
||||
"""
|
||||
dec_mask = torch.gt(
|
||||
tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0
|
||||
)
|
||||
dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
|
||||
input_norm = self.layer_norm_1(inputs)
|
||||
all_input = input_norm
|
||||
if previous_input is not None:
|
||||
all_input = torch.cat((previous_input, input_norm), dim=1)
|
||||
dec_mask = None
|
||||
|
||||
query = self.self_attn(
|
||||
all_input,
|
||||
all_input,
|
||||
input_norm,
|
||||
mask=dec_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="self",
|
||||
)
|
||||
query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)
|
||||
|
||||
query = self.drop(query) + inputs
|
||||
|
||||
query_norm = self.layer_norm_2(query)
|
||||
mid = self.context_attn(
|
||||
memory_bank,
|
||||
memory_bank,
|
||||
query_norm,
|
||||
mask=src_pad_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="context",
|
||||
memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
|
||||
)
|
||||
output = self.feed_forward(self.drop(mid) + query)
|
||||
|
||||
@@ -492,14 +424,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
self.final_linear = nn.Linear(model_dim, model_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
mask=None,
|
||||
layer_cache=None,
|
||||
type=None,
|
||||
predefined_graph_1=None,
|
||||
self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
|
||||
):
|
||||
"""
|
||||
Compute the context vector and the attention vectors.
|
||||
@@ -531,11 +456,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
|
||||
def unshape(x):
|
||||
""" compute context """
|
||||
return (
|
||||
x.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view(batch_size, -1, head_count * dim_per_head)
|
||||
)
|
||||
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
|
||||
|
||||
# 1) Project key, value, and query.
|
||||
if layer_cache is not None:
|
||||
@@ -554,9 +475,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
if layer_cache["self_keys"] is not None:
|
||||
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
|
||||
if layer_cache["self_values"] is not None:
|
||||
value = torch.cat(
|
||||
(layer_cache["self_values"].to(device), value), dim=2
|
||||
)
|
||||
value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
|
||||
layer_cache["self_keys"] = key
|
||||
layer_cache["self_values"] = value
|
||||
elif type == "context":
|
||||
@@ -637,13 +556,9 @@ class DecoderState(object):
|
||||
sizes = e.size()
|
||||
br = sizes[1]
|
||||
if len(sizes) == 3:
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[
|
||||
:, :, idx
|
||||
]
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
|
||||
else:
|
||||
sent_states = e.view(
|
||||
sizes[0], beam_size, br // beam_size, sizes[2], sizes[3]
|
||||
)[:, :, idx]
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]
|
||||
|
||||
sent_states.data.copy_(sent_states.data.index_select(1, positions))
|
||||
|
||||
@@ -716,11 +631,7 @@ class TransformerDecoderState(DecoderState):
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return (
|
||||
0.5
|
||||
* x
|
||||
* (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
@@ -758,9 +669,7 @@ class PositionwiseFeedForward(nn.Module):
|
||||
def build_predictor(args, tokenizer, symbols, model, logger=None):
|
||||
# we should be able to refactor the global scorer a lot
|
||||
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
|
||||
translator = Translator(
|
||||
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger
|
||||
)
|
||||
translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
|
||||
return translator
|
||||
|
||||
|
||||
@@ -891,9 +800,7 @@ class Translator(object):
|
||||
Shouldn't need the original dataset.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self._fast_translate_batch(
|
||||
batch, self.max_length, min_length=self.min_length
|
||||
)
|
||||
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)
|
||||
|
||||
# Where the beam search lives
|
||||
# I have no idea why it is being called from the method above
|
||||
@@ -912,26 +819,18 @@ class Translator(object):
|
||||
mask_src = batch.mask_src
|
||||
|
||||
src_features = self.model.bert(src, segs, mask_src)
|
||||
dec_states = self.model.decoder.init_decoder_state(
|
||||
src, src_features, with_cache=True
|
||||
)
|
||||
dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
|
||||
device = src_features.device
|
||||
|
||||
# Tile states and memory beam_size times.
|
||||
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
|
||||
src_features = tile(src_features, beam_size, dim=0)
|
||||
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
|
||||
beam_offset = torch.arange(
|
||||
0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device
|
||||
)
|
||||
alive_seq = torch.full(
|
||||
[batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device
|
||||
)
|
||||
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
|
||||
alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)
|
||||
|
||||
# Give full probability to the first beam on the first step.
|
||||
topk_log_probs = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (beam_size - 1), device=device
|
||||
).repeat(batch_size)
|
||||
topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)
|
||||
|
||||
# Structure that holds finished hypotheses.
|
||||
hypotheses = [[] for _ in range(batch_size)] # noqa: F812
|
||||
@@ -948,9 +847,7 @@ class Translator(object):
|
||||
# Decoder forward.
|
||||
decoder_input = decoder_input.transpose(0, 1)
|
||||
|
||||
dec_out, dec_states = self.model.decoder(
|
||||
decoder_input, src_features, dec_states, step=step
|
||||
)
|
||||
dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)
|
||||
|
||||
# Generator forward.
|
||||
log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
|
||||
@@ -978,10 +875,7 @@ class Translator(object):
|
||||
words = " ".join(words).replace(" ##", "").split()
|
||||
if len(words) <= 3:
|
||||
continue
|
||||
trigrams = [
|
||||
(words[i - 1], words[i], words[i + 1])
|
||||
for i in range(1, len(words) - 1)
|
||||
]
|
||||
trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
|
||||
trigram = tuple(trigrams[-1])
|
||||
if trigram in trigrams[:-1]:
|
||||
fail = True
|
||||
@@ -999,15 +893,11 @@ class Translator(object):
|
||||
topk_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Map beam_index to batch_index in the flat representation.
|
||||
batch_index = topk_beam_index + beam_offset[
|
||||
: topk_beam_index.size(0)
|
||||
].unsqueeze(1)
|
||||
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
|
||||
select_indices = batch_index.view(-1)
|
||||
|
||||
# Append last prediction.
|
||||
alive_seq = torch.cat(
|
||||
[alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
|
||||
)
|
||||
alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)
|
||||
|
||||
is_finished = topk_ids.eq(self.end_token)
|
||||
if step + 1 == max_length:
|
||||
@@ -1040,15 +930,11 @@ class Translator(object):
|
||||
topk_log_probs = topk_log_probs.index_select(0, non_finished)
|
||||
batch_index = batch_index.index_select(0, non_finished)
|
||||
batch_offset = batch_offset.index_select(0, non_finished)
|
||||
alive_seq = predictions.index_select(0, non_finished).view(
|
||||
-1, alive_seq.size(-1)
|
||||
)
|
||||
alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
|
||||
# Reorder states.
|
||||
select_indices = batch_index.view(-1)
|
||||
src_features = src_features.index_select(0, select_indices)
|
||||
dec_states.map_batch_fn(
|
||||
lambda state, dim: state.index_select(dim, select_indices)
|
||||
)
|
||||
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
|
||||
|
||||
return results
|
||||
|
||||
@@ -1089,14 +975,7 @@ def tile(x, count, dim=0):
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = (
|
||||
x.view(batch, -1)
|
||||
.transpose(0, 1)
|
||||
.repeat(count, 1)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(*out_size)
|
||||
)
|
||||
x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
||||
@@ -1107,6 +986,7 @@ def tile(x, count, dim=0):
|
||||
# a finetuning script.
|
||||
#
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
|
||||
""" Specific optimizer for BertSum.
|
||||
|
||||
@@ -1126,16 +1006,10 @@ class BertSumOptimizer(object):
|
||||
|
||||
self.optimizers = {
|
||||
"encoder": torch.optim.Adam(
|
||||
model.encoder.parameters(),
|
||||
lr=lr["encoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
),
|
||||
"decoder": torch.optim.Adam(
|
||||
model.decoder.parameters(),
|
||||
lr=lr["decoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1143,9 +1017,7 @@ class BertSumOptimizer(object):
|
||||
self.current_learning_rates = {}
|
||||
|
||||
def _update_rate(self, stack):
|
||||
return self.lr[stack] * min(
|
||||
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
|
||||
)
|
||||
return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer_decoder.zero_grad()
|
||||
|
||||
Reference in New Issue
Block a user