Fix XLNet bias size mismatch when resizing embeddings (cf. #1124)
This commit is contained in:
@@ -327,6 +327,14 @@ class PreTrainedModel(nn.Module):
|
|||||||
else:
|
else:
|
||||||
first_module.weight = second_module.weight
|
first_module.weight = second_module.weight
|
||||||
|
|
||||||
|
if hasattr(first_module, 'bias'):
|
||||||
|
first_module.bias.data = torch.nn.functional.pad(
|
||||||
|
first_module.bias.data,
|
||||||
|
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
|
||||||
|
'constant',
|
||||||
|
0
|
||||||
|
)
|
||||||
|
|
||||||
def resize_token_embeddings(self, new_num_tokens=None):
|
def resize_token_embeddings(self, new_num_tokens=None):
|
||||||
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
||||||
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
||||||
|
|||||||
Reference in New Issue
Block a user