[Bart/Memory] don't create lm_head (#3323)
* delete lm_head, skips weight tying * Fixed s3
This commit is contained in:
@@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
|
||||
def __init__(self, config: BartConfig):
|
||||
super().__init__(config)
|
||||
# if base_model is None:
|
||||
base_model = BartModel(config)
|
||||
self.model = base_model
|
||||
self.lm_head = _make_linear_from_emb(self.model.shared)
|
||||
|
||||
def tie_weights(self):
|
||||
pass # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
|
||||
|
||||
@add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
@@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
decoder_cached_states=decoder_cached_states,
|
||||
generation_mode=generation_mode,
|
||||
)
|
||||
lm_logits = self.lm_head(outputs[0])
|
||||
lm_logits = F.linear(outputs[0], self.model.shared.weight)
|
||||
outputs = (lm_logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
if lm_labels is not None:
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
@@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
||||
return self.model.encoder
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
return _make_linear_from_emb(self.model.shared) # make it on the fly
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
|
||||
Reference in New Issue
Block a user