Fixing conversation test for torch 1.8 (#10545)
This commit is contained in:
@@ -53,9 +53,9 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
|||||||
model = GPT2LMHeadModel(config)
|
model = GPT2LMHeadModel(config)
|
||||||
# Force model output to be L
|
# Force model output to be L
|
||||||
V, D = model.lm_head.weight.shape
|
V, D = model.lm_head.weight.shape
|
||||||
bias = torch.zeros(V, requires_grad=True)
|
bias = torch.zeros(V)
|
||||||
weight = torch.zeros((V, D), requires_grad=True)
|
|
||||||
bias[76] = 1
|
bias[76] = 1
|
||||||
|
weight = torch.zeros((V, D), requires_grad=True)
|
||||||
|
|
||||||
model.lm_head.bias = torch.nn.Parameter(bias)
|
model.lm_head.bias = torch.nn.Parameter(bias)
|
||||||
model.lm_head.weight = torch.nn.Parameter(weight)
|
model.lm_head.weight = torch.nn.Parameter(weight)
|
||||||
|
|||||||
Reference in New Issue
Block a user