[Longformer] more models + model cards (#4628)
* adding freeze roberta models * model cards * lint
This commit is contained in:
@@ -23,9 +23,11 @@ from .configuration_roberta import RobertaConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
|
||||
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
|
||||
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
|
||||
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/config.json",
|
||||
"allenai/longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/config.json",
|
||||
"allenai/longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/config.json",
|
||||
"allenai/longformer-base-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096-extra.pos.embd.only/config.json",
|
||||
"allenai/longformer-large-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-extra.pos.embd.only/config.json",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -31,9 +31,11 @@ from .modeling_roberta import RobertaLMHead, RobertaModel
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
|
||||
"longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
|
||||
"longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
|
||||
"allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
|
||||
"allenai/longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
|
||||
"allenai/longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
|
||||
"allenai/longformer-base-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096-extra.pos.embd.only/pytorch_model.bin",
|
||||
"allenai/longformer-large-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-extra.pos.embd.only/pytorch_model.bin",
|
||||
}
|
||||
|
||||
|
||||
@@ -851,8 +853,7 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
|
||||
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||
attention_mask = attention_mask.expand_as(input_ids) < question_end_index
|
||||
|
||||
attention_mask = attention_mask.int() + 1 # True => global attention; False => local attention
|
||||
return attention_mask.long()
|
||||
return attention_mask.long() + 1 # True => global attention; False => local attention
|
||||
|
||||
def _get_question_end_index(self, input_ids):
|
||||
sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
|
||||
|
||||
@@ -24,13 +24,21 @@ logger = logging.getLogger(__name__)
|
||||
# vocab and merges same as roberta
|
||||
vocab_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json"
|
||||
merges_url = "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt"
|
||||
_all_longformer_models = ["longformer-base-4096", "longformer-large-4096", "longformer-large-4096-finetuned-triviaqa"]
|
||||
_all_longformer_models = [
|
||||
"allenai/longformer-base-4096",
|
||||
"allenai/longformer-large-4096",
|
||||
"allenai/longformer-large-4096-finetuned-triviaqa",
|
||||
"allenai/longformer-base-4096-extra.pos.embd.only",
|
||||
"allenai/longformer-large-4096-extra.pos.embd.only",
|
||||
]
|
||||
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
"longformer-base-4096": 4096,
|
||||
"longformer-large-4096": 4096,
|
||||
"longformer-large-4096-finetuned-triviaqa": 4096,
|
||||
"allenai/longformer-base-4096": 4096,
|
||||
"allenai/longformer-large-4096": 4096,
|
||||
"allenai/longformer-large-4096-finetuned-triviaqa": 4096,
|
||||
"allenai/longformer-base-4096-extra.pos.embd.only": 4096,
|
||||
"allenai/longformer-large-4096-extra.pos.embd.only": 4096,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user