[Bart/Memory] don't create lm_head (#3323)

* delete lm_head, skips weight tying
* Fixed s3
This commit is contained in:
Sam Shleifer
2020-03-26 18:40:39 -04:00
committed by GitHub
parent 5ad2ea06af
commit 39371ee454
3 changed files with 23 additions and 8 deletions

View File

@@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
def __init__(self, config: BartConfig):
super().__init__(config)
# if base_model is None:
base_model = BartModel(config)
self.model = base_model
self.lm_head = _make_linear_from_emb(self.model.shared)
def tie_weights(self):
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
def forward(
@@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_cached_states=decoder_cached_states,
generation_mode=generation_mode,
)
lm_logits = self.lm_head(outputs[0])
lm_logits = F.linear(outputs[0], self.model.shared.weight)
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
if lm_labels is not None:
loss_fct = nn.CrossEntropyLoss()
@@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return self.model.encoder
def get_output_embeddings(self):
return self.lm_head
return _make_linear_from_emb(self.model.shared) # make it on the fly
@add_start_docstrings(