[style] consistent nn. and nn.functional: part 3 tests (#12155)

* consistent nn. and nn.functional: p3 templates

* restore
This commit is contained in:
Stas Bekman
2021-06-14 12:18:22 -07:00
committed by GitHub
parent d9c0d08f9a
commit 372ab9cd6d
14 changed files with 93 additions and 81 deletions

View File

@@ -32,6 +32,7 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
if is_torch_available():
import torch
from torch import nn
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
@@ -59,8 +60,8 @@ class SimpleConversationPipelineTests(unittest.TestCase):
bias[76] = 1
weight = torch.zeros((V, D), requires_grad=True)
model.lm_head.bias = torch.nn.Parameter(bias)
model.lm_head.weight = torch.nn.Parameter(weight)
model.lm_head.bias = nn.Parameter(bias)
model.lm_head.weight = nn.Parameter(weight)
# # Created with:
# import tempfile