XLNet bias fix on resize embeddings (cf #1124)

This commit is contained in:
LysandreJik
2019-08-31 00:50:59 -04:00
parent d7a4c3252e
commit e0f867a9ba

View File

@@ -327,6 +327,14 @@ class PreTrainedModel(nn.Module):
else: else:
first_module.weight = second_module.weight first_module.weight = second_module.weight
if hasattr(first_module, 'bias'):
first_module.bias.data = torch.nn.functional.pad(
first_module.bias.data,
(0, first_module.weight.shape[0] - first_module.bias.shape[0]),
'constant',
0
)
def resize_token_embeddings(self, new_num_tokens=None): def resize_token_embeddings(self, new_num_tokens=None):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size. """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.