[Longformer] more models + model cards (#4628)

* adding freeze roberta models

* model cards

* lint
This commit is contained in:
Iz Beltagy
2020-05-28 02:11:05 -07:00
committed by GitHub
parent 96f57c9ccb
commit ef03ae874f
5 changed files with 67 additions and 12 deletions

View File

@@ -23,9 +23,11 @@ from .configuration_roberta import RobertaConfig
logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
"allenai/longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
"allenai/longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
"allenai/longformer-base-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096-extra.pos.embd.only/config.json",
"allenai/longformer-large-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-extra.pos.embd.only/config.json",
}

View File

@@ -31,9 +31,11 @@ from .modeling_roberta import RobertaLMHead, RobertaModel
logger = logging.getLogger(__name__)
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
"allenai/longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
"allenai/longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
"allenai/longformer-base-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096-extra.pos.embd.only/pytorch_model.bin",
"allenai/longformer-large-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-extra.pos.embd.only/pytorch_model.bin",
}
@@ -851,8 +853,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
attention_mask = attention_mask.expand_as(input_ids) < question_end_index
attention_mask = attention_mask.int() + 1 # True => global attention; False => local attention
return attention_mask.long()
return attention_mask.long() + 1 # True => global attention; False => local attention
def _get_question_end_index(self, input_ids):
sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()

View File

@@ -24,13 +24,21 @@ logger = logging.getLogger(__name__)
# vocab and merges same as roberta
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
_all_longformer_models = ["longformer-base-4096", "longformer-large-4096", "longformer-large-4096-finetuned-triviaqa"]
_all_longformer_models = [
"allenai/longformer-base-4096",
"allenai/longformer-large-4096",
"allenai/longformer-large-4096-finetuned-triviaqa",
"allenai/longformer-base-4096-extra.pos.embd.only",
"allenai/longformer-large-4096-extra.pos.embd.only",
]
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"longformer-base-4096": 4096,
"longformer-large-4096": 4096,
"longformer-large-4096-finetuned-triviaqa": 4096,
"allenai/longformer-base-4096": 4096,
"allenai/longformer-large-4096": 4096,
"allenai/longformer-large-4096-finetuned-triviaqa": 4096,
"allenai/longformer-base-4096-extra.pos.embd.only": 4096,
"allenai/longformer-large-4096-extra.pos.embd.only": 4096,
}