Fixing conversation test for torch 1.8 (#10545)
This commit is contained in:
@@ -53,9 +53,9 @@ class SimpleConversationPipelineTests(unittest.TestCase):
|
||||
model = GPT2LMHeadModel(config)
|
||||
# Force model output to be L
|
||||
V, D = model.lm_head.weight.shape
|
||||
bias = torch.zeros(V, requires_grad=True)
|
||||
weight = torch.zeros((V, D), requires_grad=True)
|
||||
bias = torch.zeros(V)
|
||||
bias[76] = 1
|
||||
weight = torch.zeros((V, D), requires_grad=True)
|
||||
|
||||
model.lm_head.bias = torch.nn.Parameter(bias)
|
||||
model.lm_head.weight = torch.nn.Parameter(weight)
|
||||
|
||||
Reference in New Issue
Block a user