[style] consistent nn. and nn.functional: part 3 tests (#12155)
* consistent nn. and nn.functional: p3 templates * restore
This commit is contained in:
@@ -32,6 +32,7 @@ from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
@@ -59,8 +60,8 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
||||
bias[76] = 1
|
||||
weight = torch.zeros((V, D), requires_grad=True)
|
||||
|
||||
model.lm_head.bias = torch.nn.Parameter(bias)
|
||||
model.lm_head.weight = torch.nn.Parameter(weight)
|
||||
model.lm_head.bias = nn.Parameter(bias)
|
||||
model.lm_head.weight = nn.Parameter(weight)
|
||||
|
||||
# # Created with:
|
||||
# import tempfile
|
||||
|
||||
Reference in New Issue
Block a user