[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates * restore
This commit is contained in:
@@ -32,6 +32,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
@@ -241,7 +242,7 @@ class ReformerModelTester:
|
||||
# set all position encodings to zero so that postions don't matter
|
||||
with torch.no_grad():
|
||||
embedding = model.embeddings.position_embeddings.embedding
|
||||
embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
|
||||
embedding.weight = nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
|
||||
embedding.weight.requires_grad = False
|
||||
|
||||
half_seq_len = self.seq_length // 2
|
||||
|
||||
Reference in New Issue
Block a user