From 1f72865726f7f8ca7d0202bb8cd2e487394f8c83 Mon Sep 17 00:00:00 2001 From: dougian Date: Mon, 30 Mar 2020 17:20:37 +0100 Subject: [PATCH] [BART] Update encoder and decoder on set_input_embedding (#3501) Co-authored-by: Ioannis Douratsos --- src/transformers/modeling_bart.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index f187c31384..d237c0448d 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -805,6 +805,8 @@ class BartModel(PretrainedBartModel): def set_input_embeddings(self, value): self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared def get_output_embeddings(self): return _make_linear_from_emb(self.shared) # make it on the fly