Doctest longformer (#16441)
* Add initial doctring changes * make fixup * Add TF doc changes * fix seq classifier output * fix quality errors * t * swithc head to random init * Fix expected outputs * Update src/transformers/models/longformer/modeling_longformer.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -1782,23 +1782,31 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
Mask filling example:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import LongformerForMaskedLM, LongformerTokenizer
|
||||
>>> from transformers import LongformerTokenizer, LongformerForMaskedLM
|
||||
|
||||
>>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
||||
>>> tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
|
||||
>>> model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
||||
```
|
||||
|
||||
>>> SAMPLE_TEXT = " ".join(["Hello world! "] * 1000) # long input document
|
||||
>>> input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) # batch of size 1
|
||||
Let's try a very long input.
|
||||
|
||||
>>> attention_mask = None # default is local attention everywhere, which is a good choice for MaskedLM
|
||||
>>> # check `LongformerModel.forward` for more details how to set *attention_mask*
|
||||
>>> outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
|
||||
>>> loss = outputs.loss
|
||||
>>> prediction_logits = outputs.logits
|
||||
```python
|
||||
>>> TXT = (
|
||||
... "My friends are <mask> but they eat too many carbs."
|
||||
... + " That's why I decide not to eat with them." * 300
|
||||
... )
|
||||
>>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
|
||||
>>> logits = model(input_ids).logits
|
||||
|
||||
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
||||
>>> probs = logits[0, masked_index].softmax(dim=0)
|
||||
>>> values, predictions = probs.topk(5)
|
||||
|
||||
>>> tokenizer.decode(predictions).split()
|
||||
['healthy', 'skinny', 'thin', 'good', 'vegetarian']
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
@@ -1860,9 +1868,11 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint="jpelhaw/longformer-base-plagiarism-detection",
|
||||
output_type=LongformerSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output="'ORIGINAL'",
|
||||
expected_loss=5.44,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
@@ -2127,9 +2137,11 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint="brad1141/Longformer-finetuned-norm",
|
||||
output_type=LongformerTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output="['Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence', 'Evidence']",
|
||||
expected_loss=0.63,
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -2102,10 +2102,12 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint="allenai/longformer-base-4096",
|
||||
output_type=TFLongformerMaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
mask="<mask>",
|
||||
expected_output="' Paris'",
|
||||
expected_loss=0.44,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
@@ -2198,6 +2200,8 @@ class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAn
|
||||
checkpoint="allenai/longformer-large-4096-finetuned-triviaqa",
|
||||
output_type=TFLongformerQuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output="' puppet'",
|
||||
expected_loss=0.96,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
@@ -2344,9 +2348,11 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint="hf-internal-testing/tiny-random-longformer",
|
||||
output_type=TFLongformerSequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output="'LABEL_1'",
|
||||
expected_loss=0.69,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
@@ -2582,9 +2588,11 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
|
||||
@add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
checkpoint="hf-internal-testing/tiny-random-longformer",
|
||||
output_type=TFLongformerTokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
expected_output="['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1']",
|
||||
expected_loss=0.59,
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
|
||||
@@ -31,6 +31,8 @@ src/transformers/models/glpn/modeling_glpn.py
|
||||
src/transformers/models/gpt2/modeling_gpt2.py
|
||||
src/transformers/models/gptj/modeling_gptj.py
|
||||
src/transformers/models/hubert/modeling_hubert.py
|
||||
src/transformers/models/longformer/modeling_longformer.py
|
||||
src/transformers/models/longformer/modeling_tf_longformer.py
|
||||
src/transformers/models/marian/modeling_marian.py
|
||||
src/transformers/models/mbart/modeling_mbart.py
|
||||
src/transformers/models/mobilebert/modeling_mobilebert.py
|
||||
|
||||
Reference in New Issue
Block a user