Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0
* Check copies
* Fix code
This commit is contained in:
@@ -229,20 +229,14 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
assert end_logits_tea.size() == end_logits_stu.size()
|
||||
|
||||
loss_fct = nn.KLDivLoss(reduction="batchmean")
|
||||
loss_start = (
|
||||
loss_fct(
|
||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
loss_fct(
|
||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_start = loss_fct(
|
||||
nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature**2)
|
||||
loss_end = loss_fct(
|
||||
nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
) * (args.temperature**2)
|
||||
loss_ce = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_ce * loss_ce + args.alpha_squad * loss
|
||||
|
||||
@@ -450,7 +450,7 @@ def main():
|
||||
negative_indices = batch.pop("sampled_negative_indices")
|
||||
|
||||
gumbel_temperature = jnp.clip(
|
||||
model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
|
||||
model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay**state.step,
|
||||
a_min=model_args.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
|
||||
@@ -1264,7 +1264,7 @@ class Res5ROIHeads(nn.Module):
|
||||
self.feature_strides = {k: v.stride for k, v in input_shape.items()}
|
||||
self.feature_channels = {k: v.channels for k, v in input_shape.items()}
|
||||
self.cls_agnostic_bbox_reg = cfg.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
|
||||
self.stage_channel_factor = 2 ** 3 # res5 is 8x res2
|
||||
self.stage_channel_factor = 2**3 # res5 is 8x res2
|
||||
self.out_channels = cfg.RESNETS.RES2_OUT_CHANNELS * self.stage_channel_factor
|
||||
|
||||
# self.proposal_matcher = Matcher(
|
||||
@@ -1419,7 +1419,7 @@ class AnchorGenerator(nn.Module):
|
||||
|
||||
anchors = []
|
||||
for size in sizes:
|
||||
area = size ** 2.0
|
||||
area = size**2.0
|
||||
for aspect_ratio in aspect_ratios:
|
||||
w = math.sqrt(area / aspect_ratio)
|
||||
h = aspect_ratio * w
|
||||
|
||||
@@ -84,7 +84,7 @@ def schedule_threshold(
|
||||
spars_warmup_steps = initial_warmup * warmup_steps
|
||||
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
|
||||
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff**3)
|
||||
regu_lambda = final_lambda * threshold / final_threshold
|
||||
return threshold, regu_lambda
|
||||
|
||||
@@ -285,14 +285,11 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_logits = (
|
||||
nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_logits = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature**2)
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ def schedule_threshold(
|
||||
spars_warmup_steps = initial_warmup * warmup_steps
|
||||
spars_schedu_steps = (final_warmup + initial_warmup) * warmup_steps
|
||||
mul_coeff = 1 - (step - spars_warmup_steps) / (total_step - spars_schedu_steps)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff ** 3)
|
||||
threshold = final_threshold + (initial_threshold - final_threshold) * (mul_coeff**3)
|
||||
regu_lambda = final_lambda * threshold / final_threshold
|
||||
return threshold, regu_lambda
|
||||
|
||||
@@ -306,22 +306,16 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_start = (
|
||||
nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_start = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature**2)
|
||||
loss_end = nn.functional.kl_div(
|
||||
input=nn.functional.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=nn.functional.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature**2)
|
||||
loss_logits = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
|
||||
@@ -442,7 +442,7 @@ class BeamSearchScorerTS(torch.nn.Module):
|
||||
elif self.do_early_stopping:
|
||||
return True
|
||||
else:
|
||||
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
||||
cur_score = best_sum_logprobs / cur_len**self.length_penalty
|
||||
ret = self._beam_hyps_worst_scores[hypo_idx].item() >= cur_score
|
||||
return ret
|
||||
|
||||
|
||||
@@ -550,7 +550,7 @@ def generate_text_pplm(
|
||||
|
||||
unpert_probs = nn.functional.softmax(unpert_logits[:, -1, :], dim=-1)
|
||||
|
||||
pert_probs = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
pert_probs = (pert_probs**gm_scale) * (unpert_probs ** (1 - gm_scale)) # + SMALL_CONST
|
||||
pert_probs = top_k_filter(pert_probs, k=top_k, probs=True) # + SMALL_CONST
|
||||
|
||||
# rescale
|
||||
|
||||
@@ -1264,7 +1264,7 @@ class Res5ROIHeads(nn.Module):
|
||||
self.feature_strides = {k: v.stride for k, v in input_shape.items()}
|
||||
self.feature_channels = {k: v.channels for k, v in input_shape.items()}
|
||||
self.cls_agnostic_bbox_reg = cfg.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG
|
||||
self.stage_channel_factor = 2 ** 3 # res5 is 8x res2
|
||||
self.stage_channel_factor = 2**3 # res5 is 8x res2
|
||||
self.out_channels = cfg.RESNETS.RES2_OUT_CHANNELS * self.stage_channel_factor
|
||||
|
||||
# self.proposal_matcher = Matcher(
|
||||
@@ -1419,7 +1419,7 @@ class AnchorGenerator(nn.Module):
|
||||
|
||||
anchors = []
|
||||
for size in sizes:
|
||||
area = size ** 2.0
|
||||
area = size**2.0
|
||||
for aspect_ratio in aspect_ratios:
|
||||
w = math.sqrt(area / aspect_ratio)
|
||||
h = aspect_ratio * w
|
||||
|
||||
@@ -273,11 +273,11 @@ class Wav2Vec2PreTrainer(Trainer):
|
||||
# make sure gumbel softmax temperature is decayed
|
||||
if self.args.n_gpu > 1 or self.deepspeed:
|
||||
model.module.set_gumbel_temperature(
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay**self.num_update_step, self.min_gumbel_temp)
|
||||
)
|
||||
else:
|
||||
model.set_gumbel_temperature(
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay**self.num_update_step, self.min_gumbel_temp)
|
||||
)
|
||||
|
||||
return loss.detach()
|
||||
|
||||
Reference in New Issue
Block a user