Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -229,20 +229,14 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
assert end_logits_tea.size() == end_logits_stu.size()
|
||||
|
||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
loss_start = (
|
||||
loss_fct(
|
||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
loss_fct(
|
||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_start = loss_fct(
|
||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature**2)
|
||||
loss_end = loss_fct(
|
||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature**2)
|
||||
loss_ce = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||
|
||||
Reference in New Issue
Block a user