Apply ruff flake8-comprehensions (#21694)

This commit is contained in:
Aaron Gokaslan
2023-02-22 03:14:54 -05:00
committed by GitHub
parent df06fb1f0b
commit 5e8c8eb5ba
230 changed files with 971 additions and 955 deletions

View File

@@ -127,11 +127,9 @@ def perturb_past(
_, _, _, curr_length, _ = past[0].shape
if curr_length > window_length and window_length > 0:
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])
zeros_key_val_shape = (
tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
)
zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])
ones_mask = torch.ones(ones_key_val_shape)
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)