Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -350,7 +350,7 @@ def get_grad_norm(params, scale=1):
|
||||
if p.grad is not None:
|
||||
param_norm = (p.grad.detach().data / scale).norm(2)
|
||||
total_norm += param_norm.item() ** 2
|
||||
total_norm = total_norm ** 0.5
|
||||
total_norm = total_norm**0.5
|
||||
return total_norm
|
||||
|
||||
|
||||
@@ -619,7 +619,7 @@ def main():
|
||||
|
||||
# update gumbel temperature
|
||||
gumbel_temperature = max(
|
||||
args.max_gumbel_temperature * args.gumbel_temperature_decay ** completed_steps,
|
||||
args.max_gumbel_temperature * args.gumbel_temperature_decay**completed_steps,
|
||||
args.min_gumbel_temperature,
|
||||
)
|
||||
if hasattr(model, "module"):
|
||||
|
||||
Reference in New Issue
Block a user