From c589eae2b83be5206dab7a899738a0995624cc82 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 26 May 2020 14:58:47 +0200 Subject: [PATCH] [Longformer For Question Answering] Conversion script, doc, small fixes (#4593) * add new longformer for question answering model * add new config as well * fix links * fix links part 2 --- docs/source/model_doc/longformer.rst | 7 ++ src/transformers/configuration_longformer.py | 1 + ...r_original_pytorch_lightning_to_pytorch.py | 86 +++++++++++++++++++ src/transformers/modeling_longformer.py | 47 +++++----- src/transformers/tokenization_longformer.py | 3 +- 5 files changed, 123 insertions(+), 21 deletions(-) create mode 100644 src/transformers/convert_longformer_original_pytorch_lightning_to_pytorch.py diff --git a/docs/source/model_doc/longformer.rst b/docs/source/model_doc/longformer.rst index 9f12a80f08..7e8e816410 100644 --- a/docs/source/model_doc/longformer.rst +++ b/docs/source/model_doc/longformer.rst @@ -67,3 +67,10 @@ LongformerForMaskedLM .. autoclass:: transformers.LongformerForMaskedLM :members: + + +LongformerForQuestionAnswering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.LongformerForQuestionAnswering + :members: diff --git a/src/transformers/configuration_longformer.py b/src/transformers/configuration_longformer.py index dedafac943..559cc4a3f2 100644 --- a/src/transformers/configuration_longformer.py +++ b/src/transformers/configuration_longformer.py @@ -25,6 +25,7 @@ logger = logging.getLogger(__name__) LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { "longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json", "longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json", + "longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json", } diff --git a/src/transformers/convert_longformer_original_pytorch_lightning_to_pytorch.py b/src/transformers/convert_longformer_original_pytorch_lightning_to_pytorch.py new file mode 100644 index 0000000000..248f2d1ed9 --- /dev/null +++ b/src/transformers/convert_longformer_original_pytorch_lightning_to_pytorch.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert Longformer checkpoint.""" + + +import argparse + +import pytorch_lightning as pl +import torch + +from transformers.modeling_longformer import LongformerForQuestionAnswering, LongformerModel + + +class LightningModel(pl.LightningModule): + def __init__(self, model): + super().__init__() + self.model = model + self.num_labels = 2 + self.qa_outputs = torch.nn.Linear(self.model.config.hidden_size, self.num_labels) + + # implement only because lightning requires to do so + def forward(self): + pass + + +def convert_longformer_qa_checkpoint_to_pytorch( + longformer_model: str, longformer_question_answering_ckpt_path: str, pytorch_dump_folder_path: str +): + + # load longformer model from model identifier + longformer = LongformerModel.from_pretrained(longformer_model) + lightning_model = LightningModel(longformer) + + ckpt = torch.load(longformer_question_answering_ckpt_path, map_location=torch.device("cpu")) + lightning_model.load_state_dict(ckpt["state_dict"]) + + # init longformer question answering model + longformer_for_qa = LongformerForQuestionAnswering.from_pretrained(longformer_model) + + # transfer weights + longformer_for_qa.longformer.load_state_dict(lightning_model.model.state_dict()) + longformer_for_qa.qa_outputs.load_state_dict(lightning_model.qa_outputs.state_dict()) + longformer_for_qa.eval() + + # save model + longformer_for_qa.save_pretrained(pytorch_dump_folder_path) + + print("Conversion succesful. Model saved under {}".format(pytorch_dump_folder_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--longformer_model", + default=None, + type=str, + required=True, + help="model identifier of longformer. 
Should be either `longformer-base-4096` or `longformer-large-4096`.", + ) + parser.add_argument( + "--longformer_question_answering_ckpt_path", + default=None, + type=str, + required=True, + help="Path the official PyTorch Lighning Checkpoint.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_longformer_qa_checkpoint_to_pytorch( + args.longformer_model, args.longformer_question_answering_ckpt_path, args.pytorch_dump_folder_path + ) diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 8ff2534bcb..3570fe5957 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -33,6 +33,7 @@ logger = logging.getLogger(__name__) LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = { "longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin", "longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin", + "longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin", } @@ -710,7 +711,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): @add_start_docstrings( - """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of + """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD / TriviaQA (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). 
""", LONGFORMER_START_DOCSTRING, ) @@ -728,26 +729,27 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): self.init_weights() - def _get_question_end_index(self, input_ids): - sep_token_indices = (input_ids == self.config.sep_token_id).nonzero() - - assert sep_token_indices.size(1) == 2, "input_ids should have two dimensions" - assert sep_token_indices.size(0) == 3 * input_ids.size( - 0 - ), "There should be exactly three separator tokens in every sample for questions answering" - - return sep_token_indices.view(input_ids.size(0), 3, 2)[:, 0, 1] - def _compute_global_attention_mask(self, input_ids): question_end_index = self._get_question_end_index(input_ids) question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 # bool attention mask with True in locations of global attention - attention_mask = torch.arange(input_ids.size(1), device=input_ids.device) + attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) attention_mask = attention_mask.expand_as(input_ids) < question_end_index - attention_mask = attention_mask.int() + 1 # from True, False to 2, 1 + attention_mask = attention_mask.int() + 1 # True => global attention; False => local attention return attention_mask.long() + def _get_question_end_index(self, input_ids): + sep_token_indices = (input_ids == self.config.sep_token_id).nonzero() + batch_size = input_ids.shape[0] + + assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" + assert ( + sep_token_indices.shape[0] == 3 * batch_size + ), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering" + + return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) def forward( self, @@ -769,7 +771,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): Positions are clamped to the length of the sequence (`sequence_length`). 
Position outside of the sequence are not taken into account for computing the loss. Returns: - :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.LongformerConfig`) and inputs: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): @@ -785,24 +787,29 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + Examples:: + from transformers import LongformerTokenizer, LongformerForQuestionAnswering import torch - tokenizer = LongformerTokenizer.from_pretrained(longformer-base-4096') - model = LongformerForQuestionAnswering.from_pretrained(longformer-base-4096') + tokenizer = LongformerTokenizer.from_pretrained("longformer-large-4096-finetuned-triviaqa") + model = LongformerForQuestionAnswering.from_pretrained("longformer-large-4096-finetuned-triviaqa") question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" - encoding = tokenizer.encode_plus(question, text) + encoding = tokenizer.encode_plus(question, text, return_tensors="pt") input_ids = encoding["input_ids"] # default is local attention everywhere # the forward method will automatically set global attention on question tokens attention_mask = encoding["attention_mask"] - start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=attention_mask) - all_tokens = tokenizer.convert_ids_to_tokens(input_ids) - answer = ' '.join(all_tokens[torch.argmax(start_scores) : 
torch.argmax(end_scores)+1]) + start_scores, end_scores = model(input_ids, attention_mask=attention_mask) + all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) + + answer_tokens = all_tokens[torch.argmax(start_scores) :torch.argmax(end_scores)+1] + answer = tokenizer.decode(tokenizer.convert_tokens_to_ids(answer_tokens)) # remove space prepending space token + """ # set global attention on question tokens diff --git a/src/transformers/tokenization_longformer.py b/src/transformers/tokenization_longformer.py index 7ac2a00901..c6986220f9 100644 --- a/src/transformers/tokenization_longformer.py +++ b/src/transformers/tokenization_longformer.py @@ -24,12 +24,13 @@ logger = logging.getLogger(__name__) # vocab and merges same as roberta vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json" merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt" -_all_longformer_models = ["longformer-base-4096", "longformer-large-4096"] +_all_longformer_models = ["longformer-base-4096", "longformer-large-4096", "longformer-large-4096-finetuned-triviaqa"] PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { "longformer-base-4096": 4096, "longformer-large-4096": 4096, + "longformer-large-4096-finetuned-triviaqa": 4096, }