[Longformer] Fix longformer documentation (#7016)
* fix longformer
* allow position ids to not be initialized
This commit is contained in:
committed by
GitHub
parent
5c4eb4b1ac
commit
120176ea29
@@ -795,6 +795,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
config_class = LongformerConfig
|
config_class = LongformerConfig
|
||||||
base_model_prefix = "longformer"
|
base_model_prefix = "longformer"
|
||||||
|
authorized_missing_keys = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -1019,11 +1020,13 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
|
|
||||||
>>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
|
>>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
|
||||||
>>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
|
>>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
|
||||||
>>> attention_mask[:, [1, 4, 21,]] = 2 # Set global attention based on the task. For example,
|
>>> global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to global attention to be deactivated for all tokens
|
||||||
|
>>> global_attention_mask[:, [1, 4, 21,]] = 1 # Set global attention to random tokens for the sake of this example
|
||||||
|
... # Usually, set global attention based on the task. For example,
|
||||||
... # classification: the <s> token
|
... # classification: the <s> token
|
||||||
... # QA: question tokens
|
... # QA: question tokens
|
||||||
... # LM: potentially on the beginning of sentences and paragraphs
|
... # LM: potentially on the beginning of sentences and paragraphs
|
||||||
>>> outputs = model(input_ids, attention_mask=attention_mask)
|
>>> outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
|
||||||
>>> sequence_output = outputs.last_hidden_state
|
>>> sequence_output = outputs.last_hidden_state
|
||||||
>>> pooled_output = outputs.pooler_output
|
>>> pooled_output = outputs.pooler_output
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user