Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -84,7 +84,7 @@ def schedule_threshold(
|
||||
spars_warmup_steps = initial_warmup * warmup_steps
|
||||
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
|
||||
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff**3)
|
||||
regu_lambda = final_lambda * threshold / final_threshold
|
||||
return threshold, regu_lambda
|
||||
|
||||
@@ -285,14 +285,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_logits = (
|
||||
nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_logits = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature**2)
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ def schedule_threshold(
|
||||
spars_warmup_steps = initial_warmup * warmup_steps
|
||||
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
|
||||
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff**3)
|
||||
regu_lambda = final_lambda * threshold / final_threshold
|
||||
return threshold, regu_lambda
|
||||
|
||||
@@ -306,22 +306,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_start = (
|
||||
nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_start = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature**2)
|
||||
loss_end = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature**2)
|
||||
loss_logits = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
Reference in New Issue
Block a user