Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -273,11 +273,11 @@ class Wav2Vec2PreTrainer(Trainer):
|
||||
# make sure gumbel softmax temperature is decayed
|
||||
if self.args.n_gpu > 1 or self.deepspeed:
|
||||
model.module.set_gumbel_temperature(
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay**self.num_update_step, self.min_gumbel_temp)
|
||||
)
|
||||
else:
|
||||
model.set_gumbel_temperature(
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay ** self.num_update_step, self.min_gumbel_temp)
|
||||
max(self.max_gumbel_temp * self.gumbel_temp_decay**self.num_update_step, self.min_gumbel_temp)
|
||||
)
|
||||
|
||||
return loss.detach()
|
||||
|
||||
Reference in New Issue
Block a user