[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates * restore
This commit is contained in:
@@ -27,6 +27,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import TransfoXLConfig, TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
|
||||
from transformers.models.transfo_xl.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
@@ -362,11 +363,11 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
|
||||
if hasattr(module, "emb_projs"):
|
||||
for i in range(len(module.emb_projs)):
|
||||
if module.emb_projs[i] is not None:
|
||||
torch.nn.init.constant_(module.emb_projs[i], 0.0003)
|
||||
nn.init.constant_(module.emb_projs[i], 0.0003)
|
||||
if hasattr(module, "out_projs"):
|
||||
for i in range(len(module.out_projs)):
|
||||
if module.out_projs[i] is not None:
|
||||
torch.nn.init.constant_(module.out_projs[i], 0.0003)
|
||||
nn.init.constant_(module.out_projs[i], 0.0003)
|
||||
|
||||
for param in ["r_emb", "r_w_bias", "r_r_bias", "r_bias"]:
|
||||
if hasattr(module, param) and getattr(module, param) is not None:
|
||||
|
||||
Reference in New Issue
Block a user