[style] consistent nn. and nn.functional: part 3 tests (#12155)

* consistent nn. and nn.functional: p3 templates

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:18:22 -07:00
committed by GitHub
parent d9c0d08f9a
commit 372ab9cd6d
14 changed files with 93 additions and 81 deletions

View File

@@ -32,6 +32,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
if is_torch_available():
import torch
from torch import nn
from transformers import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -241,7 +242,7 @@ class ReformerModelTester:
# set all position encodings to zero so that postions don't matter
with torch.no_grad():
embedding = model.embeddings.position_embeddings.embedding
embedding.weight = torch.nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
embedding.weight = nn.Parameter(torch.zeros(embedding.weight.shape).to(torch_device))
embedding.weight.requires_grad = False
half_seq_len = self.seq_length // 2