Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
@@ -34,12 +34,30 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SAMPLE_TEXT = 'Hello world! cécé herlolip'
|
||||
SAMPLE_TEXT = "Hello world! cécé herlolip"
|
||||
|
||||
|
||||
BertAbsConfig = namedtuple(
|
||||
"BertAbsConfig",
|
||||
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
|
||||
[
|
||||
"temp_dir",
|
||||
"large",
|
||||
"use_bert_emb",
|
||||
"finetune_bert",
|
||||
"encoder",
|
||||
"share_emb",
|
||||
"max_pos",
|
||||
"enc_layers",
|
||||
"enc_hidden_size",
|
||||
"enc_heads",
|
||||
"enc_ff_size",
|
||||
"enc_dropout",
|
||||
"dec_layers",
|
||||
"dec_hidden_size",
|
||||
"dec_heads",
|
||||
"dec_ff_size",
|
||||
"dec_dropout",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -119,7 +137,9 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
|
||||
output_original_generator = original.generator(output_original_model)
|
||||
|
||||
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
|
||||
output_converted_model = new_model(
|
||||
encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask
|
||||
)[0]
|
||||
output_converted_generator = new_model.generator(output_converted_model)
|
||||
|
||||
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
|
||||
@@ -136,28 +156,21 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
# The model has been saved with torch.save(model) and this is bound to the exact
|
||||
# directory structure. We save the state_dict instead.
|
||||
logging.info("saving the model's state dictionary")
|
||||
torch.save(new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin")
|
||||
torch.save(
|
||||
new_model.state_dict(), "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
"--bertabs_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.bertabs_checkpoint_path, args.pytorch_dump_folder_path,
|
||||
)
|
||||
|
||||
@@ -56,40 +56,22 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
load_bert_pretrained_extractive = True if bert_extractive_checkpoint else False
|
||||
if load_bert_pretrained_extractive:
|
||||
self.bert.model.load_state_dict(
|
||||
dict(
|
||||
[
|
||||
(n[11:], p)
|
||||
for n, p in bert_extractive_checkpoint.items()
|
||||
if n.startswith("bert.model")
|
||||
]
|
||||
),
|
||||
dict([(n[11:], p) for n, p in bert_extractive_checkpoint.items() if n.startswith("bert.model")]),
|
||||
strict=True,
|
||||
)
|
||||
|
||||
self.vocab_size = self.bert.model.config.vocab_size
|
||||
|
||||
if args.max_pos > 512:
|
||||
my_pos_embeddings = nn.Embedding(
|
||||
args.max_pos, self.bert.model.config.hidden_size
|
||||
)
|
||||
my_pos_embeddings.weight.data[
|
||||
:512
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
my_pos_embeddings.weight.data[
|
||||
512:
|
||||
] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||
my_pos_embeddings = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
|
||||
my_pos_embeddings.weight.data[:512] = self.bert.model.embeddings.position_embeddings.weight.data
|
||||
my_pos_embeddings.weight.data[512:] = self.bert.model.embeddings.position_embeddings.weight.data[-1][
|
||||
None, :
|
||||
].repeat(
|
||||
args.max_pos - 512, 1
|
||||
)
|
||||
].repeat(args.max_pos - 512, 1)
|
||||
self.bert.model.embeddings.position_embeddings = my_pos_embeddings
|
||||
tgt_embeddings = nn.Embedding(
|
||||
self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0
|
||||
)
|
||||
tgt_embeddings = nn.Embedding(self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
|
||||
|
||||
tgt_embeddings.weight = copy.deepcopy(
|
||||
self.bert.model.embeddings.word_embeddings.weight
|
||||
)
|
||||
tgt_embeddings.weight = copy.deepcopy(self.bert.model.embeddings.word_embeddings.weight)
|
||||
|
||||
self.decoder = TransformerDecoder(
|
||||
self.args.dec_layers,
|
||||
@@ -102,9 +84,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
)
|
||||
|
||||
gen_func = nn.LogSoftmax(dim=-1)
|
||||
self.generator = nn.Sequential(
|
||||
nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func
|
||||
)
|
||||
self.generator = nn.Sequential(nn.Linear(args.dec_hidden_size, args.vocab_size), gen_func)
|
||||
self.generator[0].weight = self.decoder.embeddings.weight
|
||||
|
||||
load_from_checkpoints = False if checkpoint is None else True
|
||||
@@ -127,25 +107,14 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
p.data.zero_()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_input_ids,
|
||||
decoder_input_ids,
|
||||
token_type_ids,
|
||||
encoder_attention_mask,
|
||||
decoder_attention_mask,
|
||||
self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
|
||||
):
|
||||
encoder_output = self.bert(
|
||||
input_ids=encoder_input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=encoder_attention_mask,
|
||||
input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
|
||||
)
|
||||
encoder_hidden_states = encoder_output[0]
|
||||
dec_state = self.decoder.init_decoder_state(
|
||||
encoder_input_ids, encoder_hidden_states
|
||||
)
|
||||
decoder_outputs, _ = self.decoder(
|
||||
decoder_input_ids[:, :-1], encoder_hidden_states, dec_state
|
||||
)
|
||||
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
|
||||
decoder_outputs, _ = self.decoder(decoder_input_ids[:, :-1], encoder_hidden_states, dec_state)
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
@@ -162,10 +131,7 @@ class Bert(nn.Module):
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
encoder_outputs, _ = self.model(
|
||||
input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs
|
||||
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, **kwargs
|
||||
)
|
||||
return encoder_outputs
|
||||
|
||||
@@ -196,10 +162,7 @@ class TransformerDecoder(nn.Module):
|
||||
|
||||
# Build TransformerDecoder.
|
||||
self.transformer_layers = nn.ModuleList(
|
||||
[
|
||||
TransformerDecoderLayer(d_model, heads, d_ff, dropout)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
[TransformerDecoderLayer(d_model, heads, d_ff, dropout) for _ in range(num_layers)]
|
||||
)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
|
||||
@@ -236,20 +199,14 @@ class TransformerDecoder(nn.Module):
|
||||
# Decoder padding mask
|
||||
tgt_words = tgt
|
||||
tgt_batch, tgt_len = tgt_words.size()
|
||||
tgt_pad_mask = (
|
||||
tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||
)
|
||||
tgt_pad_mask = tgt_words.data.eq(padding_idx).unsqueeze(1).expand(tgt_batch, tgt_len, tgt_len)
|
||||
|
||||
# Encoder padding mask
|
||||
if memory_mask is not None:
|
||||
src_len = memory_mask.size(-1)
|
||||
src_pad_mask = memory_mask.expand(src_batch, tgt_len, src_len)
|
||||
else:
|
||||
src_pad_mask = (
|
||||
src_words.data.eq(padding_idx)
|
||||
.unsqueeze(1)
|
||||
.expand(src_batch, tgt_len, src_len)
|
||||
)
|
||||
src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1).expand(src_batch, tgt_len, src_len)
|
||||
|
||||
# Pass through the embeddings
|
||||
emb = self.embeddings(input_ids)
|
||||
@@ -271,9 +228,7 @@ class TransformerDecoder(nn.Module):
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=prev_layer_input,
|
||||
layer_cache=state.cache["layer_{}".format(i)]
|
||||
if state.cache is not None
|
||||
else None,
|
||||
layer_cache=state.cache["layer_{}".format(i)] if state.cache is not None else None,
|
||||
step=step,
|
||||
)
|
||||
if state.cache is None:
|
||||
@@ -303,9 +258,7 @@ class PositionalEncoding(nn.Module):
|
||||
def __init__(self, dropout, dim, max_len=5000):
|
||||
pe = torch.zeros(max_len, dim)
|
||||
position = torch.arange(0, max_len).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
|
||||
)
|
||||
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
|
||||
pe[:, 0::2] = torch.sin(position.float() * div_term)
|
||||
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
@@ -356,14 +309,7 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs,
|
||||
memory_bank,
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=None,
|
||||
layer_cache=None,
|
||||
step=None,
|
||||
self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -380,34 +326,20 @@ class TransformerDecoderLayer(nn.Module):
|
||||
* all_input `[batch_size x current_step x model_dim]`
|
||||
|
||||
"""
|
||||
dec_mask = torch.gt(
|
||||
tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0
|
||||
)
|
||||
dec_mask = torch.gt(tgt_pad_mask + self.mask[:, : tgt_pad_mask.size(1), : tgt_pad_mask.size(1)], 0)
|
||||
input_norm = self.layer_norm_1(inputs)
|
||||
all_input = input_norm
|
||||
if previous_input is not None:
|
||||
all_input = torch.cat((previous_input, input_norm), dim=1)
|
||||
dec_mask = None
|
||||
|
||||
query = self.self_attn(
|
||||
all_input,
|
||||
all_input,
|
||||
input_norm,
|
||||
mask=dec_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="self",
|
||||
)
|
||||
query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)
|
||||
|
||||
query = self.drop(query) + inputs
|
||||
|
||||
query_norm = self.layer_norm_2(query)
|
||||
mid = self.context_attn(
|
||||
memory_bank,
|
||||
memory_bank,
|
||||
query_norm,
|
||||
mask=src_pad_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="context",
|
||||
memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
|
||||
)
|
||||
output = self.feed_forward(self.drop(mid) + query)
|
||||
|
||||
@@ -492,14 +424,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
self.final_linear = nn.Linear(model_dim, model_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
mask=None,
|
||||
layer_cache=None,
|
||||
type=None,
|
||||
predefined_graph_1=None,
|
||||
self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
|
||||
):
|
||||
"""
|
||||
Compute the context vector and the attention vectors.
|
||||
@@ -531,11 +456,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
|
||||
def unshape(x):
|
||||
""" compute context """
|
||||
return (
|
||||
x.transpose(1, 2)
|
||||
.contiguous()
|
||||
.view(batch_size, -1, head_count * dim_per_head)
|
||||
)
|
||||
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
|
||||
|
||||
# 1) Project key, value, and query.
|
||||
if layer_cache is not None:
|
||||
@@ -554,9 +475,7 @@ class MultiHeadedAttention(nn.Module):
|
||||
if layer_cache["self_keys"] is not None:
|
||||
key = torch.cat((layer_cache["self_keys"].to(device), key), dim=2)
|
||||
if layer_cache["self_values"] is not None:
|
||||
value = torch.cat(
|
||||
(layer_cache["self_values"].to(device), value), dim=2
|
||||
)
|
||||
value = torch.cat((layer_cache["self_values"].to(device), value), dim=2)
|
||||
layer_cache["self_keys"] = key
|
||||
layer_cache["self_values"] = value
|
||||
elif type == "context":
|
||||
@@ -637,13 +556,9 @@ class DecoderState(object):
|
||||
sizes = e.size()
|
||||
br = sizes[1]
|
||||
if len(sizes) == 3:
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[
|
||||
:, :, idx
|
||||
]
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2])[:, :, idx]
|
||||
else:
|
||||
sent_states = e.view(
|
||||
sizes[0], beam_size, br // beam_size, sizes[2], sizes[3]
|
||||
)[:, :, idx]
|
||||
sent_states = e.view(sizes[0], beam_size, br // beam_size, sizes[2], sizes[3])[:, :, idx]
|
||||
|
||||
sent_states.data.copy_(sent_states.data.index_select(1, positions))
|
||||
|
||||
@@ -716,11 +631,7 @@ class TransformerDecoderState(DecoderState):
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return (
|
||||
0.5
|
||||
* x
|
||||
* (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
@@ -758,9 +669,7 @@ class PositionwiseFeedForward(nn.Module):
|
||||
def build_predictor(args, tokenizer, symbols, model, logger=None):
|
||||
# we should be able to refactor the global scorer a lot
|
||||
scorer = GNMTGlobalScorer(args.alpha, length_penalty="wu")
|
||||
translator = Translator(
|
||||
args, model, tokenizer, symbols, global_scorer=scorer, logger=logger
|
||||
)
|
||||
translator = Translator(args, model, tokenizer, symbols, global_scorer=scorer, logger=logger)
|
||||
return translator
|
||||
|
||||
|
||||
@@ -891,9 +800,7 @@ class Translator(object):
|
||||
Shouldn't need the original dataset.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return self._fast_translate_batch(
|
||||
batch, self.max_length, min_length=self.min_length
|
||||
)
|
||||
return self._fast_translate_batch(batch, self.max_length, min_length=self.min_length)
|
||||
|
||||
# Where the beam search lives
|
||||
# I have no idea why it is being called from the method above
|
||||
@@ -912,26 +819,18 @@ class Translator(object):
|
||||
mask_src = batch.mask_src
|
||||
|
||||
src_features = self.model.bert(src, segs, mask_src)
|
||||
dec_states = self.model.decoder.init_decoder_state(
|
||||
src, src_features, with_cache=True
|
||||
)
|
||||
dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
|
||||
device = src_features.device
|
||||
|
||||
# Tile states and memory beam_size times.
|
||||
dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
|
||||
src_features = tile(src_features, beam_size, dim=0)
|
||||
batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
|
||||
beam_offset = torch.arange(
|
||||
0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device
|
||||
)
|
||||
alive_seq = torch.full(
|
||||
[batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device
|
||||
)
|
||||
beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
|
||||
alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)
|
||||
|
||||
# Give full probability to the first beam on the first step.
|
||||
topk_log_probs = torch.tensor(
|
||||
[0.0] + [float("-inf")] * (beam_size - 1), device=device
|
||||
).repeat(batch_size)
|
||||
topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)
|
||||
|
||||
# Structure that holds finished hypotheses.
|
||||
hypotheses = [[] for _ in range(batch_size)] # noqa: F812
|
||||
@@ -948,9 +847,7 @@ class Translator(object):
|
||||
# Decoder forward.
|
||||
decoder_input = decoder_input.transpose(0, 1)
|
||||
|
||||
dec_out, dec_states = self.model.decoder(
|
||||
decoder_input, src_features, dec_states, step=step
|
||||
)
|
||||
dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)
|
||||
|
||||
# Generator forward.
|
||||
log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
|
||||
@@ -978,10 +875,7 @@ class Translator(object):
|
||||
words = " ".join(words).replace(" ##", "").split()
|
||||
if len(words) <= 3:
|
||||
continue
|
||||
trigrams = [
|
||||
(words[i - 1], words[i], words[i + 1])
|
||||
for i in range(1, len(words) - 1)
|
||||
]
|
||||
trigrams = [(words[i - 1], words[i], words[i + 1]) for i in range(1, len(words) - 1)]
|
||||
trigram = tuple(trigrams[-1])
|
||||
if trigram in trigrams[:-1]:
|
||||
fail = True
|
||||
@@ -999,15 +893,11 @@ class Translator(object):
|
||||
topk_ids = topk_ids.fmod(vocab_size)
|
||||
|
||||
# Map beam_index to batch_index in the flat representation.
|
||||
batch_index = topk_beam_index + beam_offset[
|
||||
: topk_beam_index.size(0)
|
||||
].unsqueeze(1)
|
||||
batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
|
||||
select_indices = batch_index.view(-1)
|
||||
|
||||
# Append last prediction.
|
||||
alive_seq = torch.cat(
|
||||
[alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
|
||||
)
|
||||
alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)
|
||||
|
||||
is_finished = topk_ids.eq(self.end_token)
|
||||
if step + 1 == max_length:
|
||||
@@ -1040,15 +930,11 @@ class Translator(object):
|
||||
topk_log_probs = topk_log_probs.index_select(0, non_finished)
|
||||
batch_index = batch_index.index_select(0, non_finished)
|
||||
batch_offset = batch_offset.index_select(0, non_finished)
|
||||
alive_seq = predictions.index_select(0, non_finished).view(
|
||||
-1, alive_seq.size(-1)
|
||||
)
|
||||
alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))
|
||||
# Reorder states.
|
||||
select_indices = batch_index.view(-1)
|
||||
src_features = src_features.index_select(0, select_indices)
|
||||
dec_states.map_batch_fn(
|
||||
lambda state, dim: state.index_select(dim, select_indices)
|
||||
)
|
||||
dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))
|
||||
|
||||
return results
|
||||
|
||||
@@ -1089,14 +975,7 @@ def tile(x, count, dim=0):
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = (
|
||||
x.view(batch, -1)
|
||||
.transpose(0, 1)
|
||||
.repeat(count, 1)
|
||||
.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(*out_size)
|
||||
)
|
||||
x = x.view(batch, -1).transpose(0, 1).repeat(count, 1).transpose(0, 1).contiguous().view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
||||
@@ -1107,6 +986,7 @@ def tile(x, count, dim=0):
|
||||
# a finetuning script.
|
||||
#
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
|
||||
""" Specific optimizer for BertSum.
|
||||
|
||||
@@ -1126,16 +1006,10 @@ class BertSumOptimizer(object):
|
||||
|
||||
self.optimizers = {
|
||||
"encoder": torch.optim.Adam(
|
||||
model.encoder.parameters(),
|
||||
lr=lr["encoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
),
|
||||
"decoder": torch.optim.Adam(
|
||||
model.decoder.parameters(),
|
||||
lr=lr["decoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -1143,9 +1017,7 @@ class BertSumOptimizer(object):
|
||||
self.current_learning_rates = {}
|
||||
|
||||
def _update_rate(self, stack):
|
||||
return self.lr[stack] * min(
|
||||
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5)
|
||||
)
|
||||
return self.lr[stack] * min(self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-1.5))
|
||||
|
||||
def zero_grad(self):
|
||||
self.optimizer_decoder.zero_grad()
|
||||
|
||||
@@ -25,9 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
"Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"]
|
||||
)
|
||||
Batch = namedtuple("Batch", ["document_names", "batch_size", "src", "segs", "mask_src", "tgt_str"])
|
||||
|
||||
|
||||
def evaluate(args):
|
||||
@@ -48,13 +46,14 @@ def evaluate(args):
|
||||
|
||||
import rouge
|
||||
import nltk
|
||||
nltk.download('punkt')
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=['rouge-n', 'rouge-l'],
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
max_n=2,
|
||||
limit_length=True,
|
||||
length_limit=args.beam_size,
|
||||
length_limit_type='words',
|
||||
length_limit_type="words",
|
||||
apply_avg=True,
|
||||
apply_best=False,
|
||||
alpha=0.5, # Default F1_score
|
||||
@@ -161,15 +160,15 @@ Recall >> {:.3f}
|
||||
F1 >> {:.3f}
|
||||
Precision >> {:.3f}
|
||||
Recall >> {:.3f}""".format(
|
||||
scores['rouge-1']['f'],
|
||||
scores['rouge-1']['p'],
|
||||
scores['rouge-1']['r'],
|
||||
scores['rouge-2']['f'],
|
||||
scores['rouge-2']['p'],
|
||||
scores['rouge-2']['r'],
|
||||
scores['rouge-l']['f'],
|
||||
scores['rouge-l']['p'],
|
||||
scores['rouge-l']['r'],
|
||||
scores["rouge-1"]["f"],
|
||||
scores["rouge-1"]["p"],
|
||||
scores["rouge-1"]["r"],
|
||||
scores["rouge-2"]["f"],
|
||||
scores["rouge-2"]["p"],
|
||||
scores["rouge-2"]["r"],
|
||||
scores["rouge-l"]["f"],
|
||||
scores["rouge-l"]["p"],
|
||||
scores["rouge-l"]["r"],
|
||||
)
|
||||
|
||||
|
||||
@@ -187,9 +186,7 @@ def build_data_iterator(args, tokenizer):
|
||||
dataset = load_and_cache_examples(args, tokenizer)
|
||||
sampler = SequentialSampler(dataset)
|
||||
collate_fn = lambda data: collate(data, tokenizer, block_size=512, device=args.device)
|
||||
iterator = DataLoader(
|
||||
dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,
|
||||
)
|
||||
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
|
||||
|
||||
return iterator
|
||||
|
||||
@@ -210,14 +207,9 @@ def collate(data, tokenizer, block_size, device):
|
||||
names = [name for name, _, _ in data]
|
||||
summaries = [" ".join(summary_list) for _, _, summary_list in data]
|
||||
|
||||
encoded_text = [
|
||||
encode_for_summarization(story, summary, tokenizer) for _, story, summary in data
|
||||
]
|
||||
encoded_text = [encode_for_summarization(story, summary, tokenizer) for _, story, summary in data]
|
||||
encoded_stories = torch.tensor(
|
||||
[
|
||||
fit_to_block_size(story, block_size, tokenizer.pad_token_id)
|
||||
for story, _ in encoded_text
|
||||
]
|
||||
[fit_to_block_size(story, block_size, tokenizer.pad_token_id) for story, _ in encoded_text]
|
||||
)
|
||||
encoder_token_type_ids = compute_token_type_ids(encoded_stories, tokenizer.cls_token_id)
|
||||
encoder_mask = build_mask(encoded_stories, tokenizer.pad_token_id)
|
||||
@@ -272,38 +264,23 @@ def main():
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
|
||||
@@ -68,9 +68,7 @@ def process_story(raw_story):
|
||||
Raises:
|
||||
IndexError: If the stoy is empty or contains no highlights.
|
||||
"""
|
||||
nonempty_lines = list(
|
||||
filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
|
||||
)
|
||||
nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))
|
||||
|
||||
# for some unknown reason some lines miss a period, add it
|
||||
nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
|
||||
@@ -135,13 +133,9 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
sentences.
|
||||
"""
|
||||
story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
|
||||
story_token_ids = [
|
||||
token for sentence in story_lines_token_ids for token in sentence
|
||||
]
|
||||
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
|
||||
summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
|
||||
summary_token_ids = [
|
||||
token for sentence in summary_lines_token_ids for token in sentence
|
||||
]
|
||||
summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]
|
||||
|
||||
return story_token_ids, summary_token_ids
|
||||
|
||||
|
||||
@@ -33,25 +33,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(
|
||||
fit_to_block_size(sequence, self.block_size, 0), expected_output
|
||||
)
|
||||
self.assertEqual(fit_to_block_size(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights returns an empty list for the summary.
|
||||
@@ -95,9 +89,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
def test_build_mask(self):
|
||||
sequence = torch.tensor([1, 2, 3, 4, 23, 23, 23])
|
||||
expected = torch.tensor([1, 1, 1, 1, 0, 0, 0])
|
||||
np.testing.assert_array_equal(
|
||||
build_mask(sequence, 23).numpy(), expected.numpy()
|
||||
)
|
||||
np.testing.assert_array_equal(build_mask(sequence, 23).numpy(), expected.numpy())
|
||||
|
||||
def test_build_mask_with_padding_equal_to_one(self):
|
||||
sequence = torch.tensor([8, 2, 3, 4, 1, 1, 1])
|
||||
@@ -106,12 +98,8 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
|
||||
def test_compute_token_type_ids(self):
|
||||
separator = 101
|
||||
batch = torch.tensor(
|
||||
[[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]]
|
||||
)
|
||||
expected = torch.tensor(
|
||||
[[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]]
|
||||
)
|
||||
batch = torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 101, 5, 6], [1, 101, 3, 4, 101, 6]])
|
||||
expected = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 1, 1]])
|
||||
|
||||
result = compute_token_type_ids(batch, separator)
|
||||
np.testing.assert_array_equal(result, expected)
|
||||
|
||||
Reference in New Issue
Block a user