fix pegasus init weights and other copied models (#36844)
* fix pegasus init weights
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix the rest of models
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix test
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix informer init
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* init weight before checking
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix roformer tests
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
* fix roformer tests
  Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
---------
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
This commit is contained in:
@@ -360,7 +360,6 @@ class AutoformerSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
||||
@@ -904,7 +903,7 @@ class AutoformerPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
|
||||
pass
|
||||
module.weight = module._init_weight(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
|
||||
@@ -233,7 +233,6 @@ class InformerSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
||||
@@ -887,7 +886,9 @@ class InformerPreTrainedModel(PreTrainedModel):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding) and not isinstance(module, InformerSinusoidalPositionalEmbedding):
|
||||
elif isinstance(module, InformerSinusoidalPositionalEmbedding):
|
||||
module.weight = module._init_weight(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
@@ -73,7 +73,6 @@ class MarianSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
||||
@@ -468,7 +467,7 @@ class MarianPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, MarianSinusoidalPositionalEmbedding):
|
||||
pass
|
||||
module.weight = module._init_weight(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
|
||||
@@ -74,7 +74,6 @@ class PegasusSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
||||
@@ -469,7 +468,7 @@ class PegasusPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, PegasusSinusoidalPositionalEmbedding):
|
||||
pass
|
||||
module.weight = module._init_weight(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
@@ -665,6 +664,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
||||
self.config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.embed_positions.weight = self.embed_positions._init_weight(self.embed_positions.weight)
|
||||
self.embed_positions.to(self.device)
|
||||
|
||||
def get_position_embeddings(self) -> nn.Embedding:
|
||||
@@ -868,6 +868,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
self.config.d_model,
|
||||
self.padding_idx,
|
||||
)
|
||||
self.embed_positions.weight = self.embed_positions._init_weight(self.embed_positions.weight)
|
||||
self.embed_positions.to(self.device)
|
||||
|
||||
def get_position_embeddings(self) -> nn.Embedding:
|
||||
|
||||
@@ -59,7 +59,6 @@ class RoFormerSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
||||
@@ -694,7 +693,7 @@ class RoFormerPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, RoFormerSinusoidalPositionalEmbedding):
|
||||
pass
|
||||
module.weight = module._init_weight(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
|
||||
@@ -233,7 +233,6 @@ class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding):
|
||||
|
||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
|
||||
super().__init__(num_positions, embedding_dim)
|
||||
self.weight = self._init_weight(self.weight)
|
||||
|
||||
@staticmethod
|
||||
def _init_weight(out: nn.Parameter) -> nn.Parameter:
|
||||
@@ -641,7 +640,7 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding):
|
||||
pass
|
||||
module.weight = module._init_weight(module.weight)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
|
||||
Reference in New Issue
Block a user