Upgrade black to version ~=22.0 (#15565)
* Upgrade black to version ~=22.0 * Check copies * Fix code
This commit is contained in:
@@ -450,7 +450,7 @@ def main():
|
||||
negative_indices = batch.pop("sampled_negative_indices")
|
||||
|
||||
gumbel_temperature = jnp.clip(
|
||||
model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay ** state.step,
|
||||
model_args.max_gumbel_temperature * model_args.gumbel_temperature_decay**state.step,
|
||||
a_min=model_args.min_gumbel_temperature,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user