Black 20 release
This commit is contained in:
@@ -28,7 +28,7 @@ BERTABS_FINETUNED_CONFIG_MAP = {
|
||||
|
||||
|
||||
class BertAbsConfig(PretrainedConfig):
|
||||
r""" Class to store the configuration of the BertAbs model.
|
||||
r"""Class to store the configuration of the BertAbs model.
|
||||
|
||||
Arguments:
|
||||
vocab_size: int
|
||||
|
||||
@@ -62,7 +62,7 @@ BertAbsConfig = namedtuple(
|
||||
|
||||
|
||||
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
""" Copy/paste and tweak the pre-trained weights provided by the creators
|
||||
"""Copy/paste and tweak the pre-trained weights provided by the creators
|
||||
of BertAbs for the internal architecture.
|
||||
"""
|
||||
|
||||
@@ -164,13 +164,22 @@ def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--bertabs_checkpoint_path", default=None, type=str, required=True, help="Path the official PyTorch dump.",
|
||||
"--bertabs_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path the official PyTorch dump.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model.",
|
||||
"--pytorch_dump_folder_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the output PyTorch model.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_bertabs_checkpoints(
|
||||
args.bertabs_checkpoint_path, args.pytorch_dump_folder_path,
|
||||
args.bertabs_checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
)
|
||||
|
||||
@@ -105,10 +105,17 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
p.data.zero_()
|
||||
|
||||
def forward(
|
||||
self, encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask,
|
||||
self,
|
||||
encoder_input_ids,
|
||||
decoder_input_ids,
|
||||
token_type_ids,
|
||||
encoder_attention_mask,
|
||||
decoder_attention_mask,
|
||||
):
|
||||
encoder_output = self.bert(
|
||||
input_ids=encoder_input_ids, token_type_ids=token_type_ids, attention_mask=encoder_attention_mask,
|
||||
input_ids=encoder_input_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
encoder_hidden_states = encoder_output[0]
|
||||
dec_state = self.decoder.init_decoder_state(encoder_input_ids, encoder_hidden_states)
|
||||
@@ -117,8 +124,7 @@ class BertAbs(BertAbsPreTrainedModel):
|
||||
|
||||
|
||||
class Bert(nn.Module):
|
||||
""" This class is not really necessary and should probably disappear.
|
||||
"""
|
||||
"""This class is not really necessary and should probably disappear."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -307,7 +313,14 @@ class TransformerDecoderLayer(nn.Module):
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(
|
||||
self, inputs, memory_bank, src_pad_mask, tgt_pad_mask, previous_input=None, layer_cache=None, step=None,
|
||||
self,
|
||||
inputs,
|
||||
memory_bank,
|
||||
src_pad_mask,
|
||||
tgt_pad_mask,
|
||||
previous_input=None,
|
||||
layer_cache=None,
|
||||
step=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -331,13 +344,25 @@ class TransformerDecoderLayer(nn.Module):
|
||||
all_input = torch.cat((previous_input, input_norm), dim=1)
|
||||
dec_mask = None
|
||||
|
||||
query = self.self_attn(all_input, all_input, input_norm, mask=dec_mask, layer_cache=layer_cache, type="self",)
|
||||
query = self.self_attn(
|
||||
all_input,
|
||||
all_input,
|
||||
input_norm,
|
||||
mask=dec_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="self",
|
||||
)
|
||||
|
||||
query = self.drop(query) + inputs
|
||||
|
||||
query_norm = self.layer_norm_2(query)
|
||||
mid = self.context_attn(
|
||||
memory_bank, memory_bank, query_norm, mask=src_pad_mask, layer_cache=layer_cache, type="context",
|
||||
memory_bank,
|
||||
memory_bank,
|
||||
query_norm,
|
||||
mask=src_pad_mask,
|
||||
layer_cache=layer_cache,
|
||||
type="context",
|
||||
)
|
||||
output = self.feed_forward(self.drop(mid) + query)
|
||||
|
||||
@@ -422,7 +447,14 @@ class MultiHeadedAttention(nn.Module):
|
||||
self.final_linear = nn.Linear(model_dim, model_dim)
|
||||
|
||||
def forward(
|
||||
self, key, value, query, mask=None, layer_cache=None, type=None, predefined_graph_1=None,
|
||||
self,
|
||||
key,
|
||||
value,
|
||||
query,
|
||||
mask=None,
|
||||
layer_cache=None,
|
||||
type=None,
|
||||
predefined_graph_1=None,
|
||||
):
|
||||
"""
|
||||
Compute the context vector and the attention vectors.
|
||||
@@ -628,7 +660,7 @@ def gelu(x):
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
""" A two-layer Feed-Forward-Network with residual layer norm.
|
||||
"""A two-layer Feed-Forward-Network with residual layer norm.
|
||||
|
||||
Args:
|
||||
d_model (int): the size of input for the first-layer of the FFN.
|
||||
@@ -770,8 +802,7 @@ class Translator(object):
|
||||
self.max_length = args.max_length
|
||||
|
||||
def translate(self, batch, step, attn_debug=False):
|
||||
""" Generates summaries from one batch of data.
|
||||
"""
|
||||
"""Generates summaries from one batch of data."""
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
batch_data = self.translate_batch(batch)
|
||||
@@ -798,8 +829,7 @@ class Translator(object):
|
||||
# Where the beam search lives
|
||||
# I have no idea why it is being called from the method above
|
||||
def _fast_translate_batch(self, batch, max_length, min_length=0):
|
||||
""" Beam Search using the encoder inputs contained in `batch`.
|
||||
"""
|
||||
"""Beam Search using the encoder inputs contained in `batch`."""
|
||||
|
||||
# The batch object is funny
|
||||
# Instead of just looking at the size of the arguments we encapsulate
|
||||
@@ -981,7 +1011,7 @@ def tile(x, count, dim=0):
|
||||
|
||||
|
||||
class BertSumOptimizer(object):
|
||||
""" Specific optimizer for BertSum.
|
||||
"""Specific optimizer for BertSum.
|
||||
|
||||
As described in [1], the authors fine-tune BertSum for abstractive
|
||||
summarization using two Adam Optimizers with different warm-up steps and
|
||||
@@ -999,10 +1029,16 @@ class BertSumOptimizer(object):
|
||||
|
||||
self.optimizers = {
|
||||
"encoder": torch.optim.Adam(
|
||||
model.encoder.parameters(), lr=lr["encoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
model.encoder.parameters(),
|
||||
lr=lr["encoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
),
|
||||
"decoder": torch.optim.Adam(
|
||||
model.decoder.parameters(), lr=lr["decoder"], betas=(beta_1, beta_2), eps=eps,
|
||||
model.decoder.parameters(),
|
||||
lr=lr["decoder"],
|
||||
betas=(beta_1, beta_2),
|
||||
eps=eps,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ def evaluate(args):
|
||||
|
||||
|
||||
def save_summaries(summaries, path, original_document_name):
|
||||
""" Write the summaries in fies that are prefixed by the original
|
||||
"""Write the summaries in fies that are prefixed by the original
|
||||
files' name with the `_summary` appended.
|
||||
|
||||
Attributes:
|
||||
@@ -125,7 +125,7 @@ def save_summaries(summaries, path, original_document_name):
|
||||
|
||||
|
||||
def format_summary(translation):
|
||||
""" Transforms the output of the `from_batch` function
|
||||
"""Transforms the output of the `from_batch` function
|
||||
into nicely formatted summaries.
|
||||
"""
|
||||
raw_summary, _, _ = translation
|
||||
@@ -190,7 +190,12 @@ def build_data_iterator(args, tokenizer):
|
||||
def collate_fn(data):
|
||||
return collate(data, tokenizer, block_size=512, device=args.device)
|
||||
|
||||
iterator = DataLoader(dataset, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_fn,)
|
||||
iterator = DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
return iterator
|
||||
|
||||
@@ -201,7 +206,7 @@ def load_and_cache_examples(args, tokenizer):
|
||||
|
||||
|
||||
def collate(data, tokenizer, block_size, device):
|
||||
""" Collate formats the data passed to the data loader.
|
||||
"""Collate formats the data passed to the data loader.
|
||||
|
||||
In particular we tokenize the data batch after batch to avoid keeping them
|
||||
all in memory. We output the data as a namedtuple to fit the original BertAbs's
|
||||
@@ -231,7 +236,7 @@ def collate(data, tokenizer, block_size, device):
|
||||
|
||||
|
||||
def decode_summary(summary_tokens, tokenizer):
|
||||
""" Decode the summary and return it in a format
|
||||
"""Decode the summary and return it in a format
|
||||
suitable for evaluation.
|
||||
"""
|
||||
summary_tokens = summary_tokens.to("cpu").numpy()
|
||||
@@ -242,8 +247,7 @@ def decode_summary(summary_tokens, tokenizer):
|
||||
|
||||
|
||||
def main():
|
||||
""" The main function defines the interface with the users.
|
||||
"""
|
||||
"""The main function defines the interface with the users."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--documents_dir",
|
||||
@@ -268,23 +272,41 @@ def main():
|
||||
)
|
||||
# EVALUATION options
|
||||
parser.add_argument(
|
||||
"--no_cuda", default=False, type=bool, help="Whether to force the execution on CPU.",
|
||||
"--no_cuda",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="Whether to force the execution on CPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.",
|
||||
"--batch_size",
|
||||
default=4,
|
||||
type=int,
|
||||
help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
# BEAM SEARCH arguments
|
||||
parser.add_argument(
|
||||
"--min_length", default=50, type=int, help="Minimum number of tokens for the summaries.",
|
||||
"--min_length",
|
||||
default=50,
|
||||
type=int,
|
||||
help="Minimum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length", default=200, type=int, help="Maixmum number of tokens for the summaries.",
|
||||
"--max_length",
|
||||
default=200,
|
||||
type=int,
|
||||
help="Maixmum number of tokens for the summaries.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beam_size", default=5, type=int, help="The number of beams to start with for each example.",
|
||||
"--beam_size",
|
||||
default=5,
|
||||
type=int,
|
||||
help="The number of beams to start with for each example.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha", default=0.95, type=float, help="The value of alpha for the length penalty in the beam search.",
|
||||
"--alpha",
|
||||
default=0.95,
|
||||
type=float,
|
||||
help="The value of alpha for the length penalty in the beam search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block_trigram",
|
||||
|
||||
@@ -43,8 +43,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_process_story_no_highlights(self):
|
||||
""" Processing a story with no highlights returns an empty list for the summary.
|
||||
"""
|
||||
"""Processing a story with no highlights returns an empty list for the summary."""
|
||||
raw_story = """It was the year of Our Lord one thousand seven hundred and
|
||||
seventy-five.\n\nSpiritual revelations were conceded to England at that
|
||||
favoured period, as at this."""
|
||||
@@ -52,8 +51,7 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
self.assertEqual(summary_lines, [])
|
||||
|
||||
def test_process_empty_story(self):
|
||||
""" An empty story returns an empty collection of lines.
|
||||
"""
|
||||
"""An empty story returns an empty collection of lines."""
|
||||
raw_story = ""
|
||||
story_lines, summary_lines = process_story(raw_story)
|
||||
self.assertEqual(story_lines, [])
|
||||
|
||||
@@ -11,7 +11,7 @@ from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class CNNDMDataset(Dataset):
|
||||
""" Abstracts the dataset used to train seq2seq models.
|
||||
"""Abstracts the dataset used to train seq2seq models.
|
||||
|
||||
The class will process the documents that are located in the specified
|
||||
folder. The preprocessing will work on any document that is reasonably
|
||||
@@ -31,7 +31,7 @@ class CNNDMDataset(Dataset):
|
||||
"""
|
||||
|
||||
def __init__(self, path="", prefix="train"):
|
||||
""" We initialize the class by listing all the documents to summarize.
|
||||
"""We initialize the class by listing all the documents to summarize.
|
||||
Files are not read in memory due to the size of some datasets (like CNN/DailyMail).
|
||||
"""
|
||||
assert os.path.isdir(path)
|
||||
@@ -60,7 +60,7 @@ class CNNDMDataset(Dataset):
|
||||
|
||||
|
||||
def process_story(raw_story):
|
||||
""" Extract the story and summary from a story file.
|
||||
"""Extract the story and summary from a story file.
|
||||
|
||||
Arguments:
|
||||
raw_story (str): content of the story file as an utf-8 encoded string.
|
||||
@@ -108,7 +108,7 @@ def _add_missing_period(line):
|
||||
|
||||
|
||||
def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
""" Adapt the source and target sequences' lengths to the block size.
|
||||
"""Adapt the source and target sequences' lengths to the block size.
|
||||
If the sequence is shorter we append padding token to the right of the sequence.
|
||||
"""
|
||||
if len(sequence) > block_size:
|
||||
@@ -119,8 +119,8 @@ def truncate_or_pad(sequence, block_size, pad_token_id):
|
||||
|
||||
|
||||
def build_mask(sequence, pad_token_id):
|
||||
""" Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1. """
|
||||
"""Builds the mask. The attention mechanism will only attend to positions
|
||||
with value 1."""
|
||||
mask = torch.ones_like(sequence)
|
||||
idx_pad_tokens = sequence == pad_token_id
|
||||
mask[idx_pad_tokens] = 0
|
||||
@@ -128,7 +128,7 @@ def build_mask(sequence, pad_token_id):
|
||||
|
||||
|
||||
def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
""" Encode the story and summary lines, and join them
|
||||
"""Encode the story and summary lines, and join them
|
||||
as specified in [1] by using `[SEP] [CLS]` tokens to separate
|
||||
sentences.
|
||||
"""
|
||||
@@ -141,7 +141,7 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
|
||||
|
||||
|
||||
def compute_token_type_ids(batch, separator_token_id):
|
||||
""" Segment embeddings as described in [1]
|
||||
"""Segment embeddings as described in [1]
|
||||
|
||||
The values {0,1} were found in the repository [2].
|
||||
|
||||
|
||||
@@ -97,4 +97,9 @@ def get_checkpoint_callback(output_dir, metric):
|
||||
|
||||
|
||||
def get_early_stopping_callback(metric, patience):
|
||||
return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,)
|
||||
return EarlyStopping(
|
||||
monitor=f"val_{metric}",
|
||||
mode="max",
|
||||
patience=patience,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
@@ -348,7 +348,10 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
|
||||
if self.different_encoder:
|
||||
with torch.no_grad():
|
||||
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
|
||||
source_ids, attention_mask=source_mask, output_hidden_states=True, use_cache=False,
|
||||
source_ids,
|
||||
attention_mask=source_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
if self.hparams.alpha_encoder_loss > 0:
|
||||
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
|
||||
|
||||
@@ -117,7 +117,12 @@ class TestSummarizationDistiller(unittest.TestCase):
|
||||
|
||||
@require_multigpu
|
||||
def test_multigpu(self):
|
||||
updates = dict(no_teacher=True, freeze_encoder=True, gpus=2, sortish_sampler=False,)
|
||||
updates = dict(
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
gpus=2,
|
||||
sortish_sampler=False,
|
||||
)
|
||||
self._test_distiller_cli(updates)
|
||||
|
||||
def test_distill_no_teacher(self):
|
||||
@@ -261,7 +266,8 @@ def test_run_eval_bart(model):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||
["model"],
|
||||
[pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
|
||||
)
|
||||
def test_finetune(model):
|
||||
args_d: dict = CHEAP_ARGS.copy()
|
||||
@@ -329,7 +335,8 @@ def test_finetune_extra_model_args():
|
||||
output_dir = tempfile.mkdtemp(prefix="output_1_")
|
||||
args_d1 = args_d.copy()
|
||||
args_d1.update(
|
||||
model_name_or_path=model, output_dir=output_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout")
|
||||
for p in extra_model_params:
|
||||
@@ -344,7 +351,8 @@ def test_finetune_extra_model_args():
|
||||
output_dir = tempfile.mkdtemp(prefix="output_2_")
|
||||
args_d2 = args_d.copy()
|
||||
args_d2.update(
|
||||
model_name_or_path=model, output_dir=output_dir,
|
||||
model_name_or_path=model,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
unsupported_param = "encoder_layerdrop"
|
||||
args_d2[unsupported_param] = 0.5
|
||||
@@ -478,7 +486,11 @@ def test_summarization_dataset_truncation(tok):
|
||||
max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
|
||||
trunc_target = 4
|
||||
train_dataset = Seq2SeqDataset(
|
||||
tokenizer, data_dir=tmp_dir, type_path="train", max_source_length=20, max_target_length=trunc_target,
|
||||
tokenizer,
|
||||
data_dir=tmp_dir,
|
||||
type_path="train",
|
||||
max_source_length=20,
|
||||
max_target_length=trunc_target,
|
||||
)
|
||||
dataloader = DataLoader(train_dataset, batch_size=2, collate_fn=train_dataset.collate_fn)
|
||||
for batch in dataloader:
|
||||
|
||||
@@ -63,7 +63,9 @@ def calculate_bleu(output_lns, refs_lns, **kwargs) -> dict:
|
||||
|
||||
|
||||
def trim_batch(
|
||||
input_ids, pad_token_id, attention_mask=None,
|
||||
input_ids,
|
||||
pad_token_id,
|
||||
attention_mask=None,
|
||||
):
|
||||
"""Remove columns that are populated exclusively by pad_token_id"""
|
||||
keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
|
||||
|
||||
Reference in New Issue
Block a user