[LED] fix global_attention_mask not being passed for generation and docs clarification about grad checkpointing (#17112)

* [LED] fixed global_attention_mask not passed for generation + docs clarification for gradient checkpointing

* LED docs clarification

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [LED] gradient_checkpointing=True should be passed to TrainingArguments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [LED] docs: remove wrong word

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* [LED] docs fix typo

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Cesare Campagnano
2022-05-17 23:44:37 +02:00
committed by GitHub
parent bad358398a
commit d9050dc768
2 changed files with 6 additions and 2 deletions

View File

@@ -44,8 +44,10 @@ Tips:
- LED makes use of *global attention* by means of the `global_attention_mask` (see - LED makes use of *global attention* by means of the `global_attention_mask` (see
[`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first [`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first
`<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question. `<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing - To fine-tune LED on all 16384, *gradient checkpointing* can be enabled in case training leads to out-of-memory (OOM)
`model.gradient_checkpointing_enable()`. errors. This can be done by executing `model.gradient_checkpointing_enable()`.
Moreover, the `use_cache=False`
flag can be used to disable the caching mechanism to save memory.
- A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing). - A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
- A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing). - A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).

View File

@@ -2441,6 +2441,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
decoder_input_ids, decoder_input_ids,
past=None, past=None,
attention_mask=None, attention_mask=None,
global_attention_mask=None,
head_mask=None, head_mask=None,
decoder_head_mask=None, decoder_head_mask=None,
cross_attn_head_mask=None, cross_attn_head_mask=None,
@@ -2458,6 +2459,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
"past_key_values": past, "past_key_values": past,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,
"head_mask": head_mask, "head_mask": head_mask,
"decoder_head_mask": decoder_head_mask, "decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask, "cross_attn_head_mask": cross_attn_head_mask,