[LED] fix global_attention_mask not being passed for generation and docs clarification about grad checkpointing (#17112)
* [LED] fixed global_attention_mask not passed for generation + docs clarification for gradient checkpointing * LED docs clarification Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * [LED] gradient_checkpointing=True should be passed to TrainingArguments Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * [LED] docs: remove wrong word Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * [LED] docs fix typo Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
bad358398a
commit
d9050dc768
@@ -44,8 +44,10 @@ Tips:
|
|||||||
- LED makes use of *global attention* by means of the `global_attention_mask` (see
|
- LED makes use of *global attention* by means of the `global_attention_mask` (see
|
||||||
[`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first
|
[`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first
|
||||||
`<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question.
|
`<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question.
|
||||||
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
|
- To fine-tune LED on all 16384, *gradient checkpointing* can be enabled in case training leads to out-of-memory (OOM)
|
||||||
`model.gradient_checkpointing_enable()`.
|
errors. This can be done by executing `model.gradient_checkpointing_enable()`.
|
||||||
|
Moreover, the `use_cache=False`
|
||||||
|
flag can be used to disable the caching mechanism to save memory.
|
||||||
- A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
|
- A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
|
||||||
- A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).
|
- A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).
|
||||||
|
|
||||||
|
|||||||
@@ -2441,6 +2441,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
past=None,
|
past=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
decoder_head_mask=None,
|
decoder_head_mask=None,
|
||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
@@ -2458,6 +2459,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
|||||||
"past_key_values": past,
|
"past_key_values": past,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"global_attention_mask": global_attention_mask,
|
||||||
"head_mask": head_mask,
|
"head_mask": head_mask,
|
||||||
"decoder_head_mask": decoder_head_mask,
|
"decoder_head_mask": decoder_head_mask,
|
||||||
"cross_attn_head_mask": cross_attn_head_mask,
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
|
|||||||
Reference in New Issue
Block a user