From d9050dc768b6a8d7ef45943059d2bbe3dafc64ec Mon Sep 17 00:00:00 2001 From: Cesare Campagnano Date: Tue, 17 May 2022 23:44:37 +0200 Subject: [PATCH] [LED] fix global_attention_mask not being passed for generation and docs clarification about grad checkpointing (#17112) * [LED] fixed global_attention_mask not passed for generation + docs clarification for gradient checkpointing * LED docs clarification Co-authored-by: Patrick von Platen * [LED] gradient_checkpointing=True should be passed to TrainingArguments Co-authored-by: Patrick von Platen * [LED] docs: remove wrong word Co-authored-by: Patrick von Platen * [LED] docs fix typo Co-authored-by: Patrick von Platen Co-authored-by: Patrick von Platen --- docs/source/en/model_doc/led.mdx | 6 ++++-- src/transformers/models/led/modeling_led.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/led.mdx b/docs/source/en/model_doc/led.mdx index db6b559bc2..63880d874f 100644 --- a/docs/source/en/model_doc/led.mdx +++ b/docs/source/en/model_doc/led.mdx @@ -44,8 +44,10 @@ Tips: - LED makes use of *global attention* by means of the `global_attention_mask` (see [`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first `` token. For question answering, it is advised to put *global attention* on all tokens of the question. -- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing - `model.gradient_checkpointing_enable()`. +- To fine-tune LED on all 16384, *gradient checkpointing* can be enabled in case training leads to out-of-memory (OOM) + errors. This can be done by executing `model.gradient_checkpointing_enable()`. + Moreover, the `use_cache=False` + flag can be used to disable the caching mechanism to save memory. - A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing). - A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing). diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index ab1ffba943..ca14de32b3 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2441,6 +2441,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): decoder_input_ids, past=None, attention_mask=None, + global_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, @@ -2458,6 +2459,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): "past_key_values": past, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, + "global_attention_mask": global_attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask,