Apply ruff flake8-comprehensions (#21694)
This commit is contained in:
@@ -127,11 +127,9 @@ def perturb_past(
|
||||
_, _, _, curr_length, _ = past[0].shape
|
||||
|
||||
if curr_length > window_length and window_length > 0:
|
||||
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple([window_length]) + tuple(past[0].shape[-1:])
|
||||
ones_key_val_shape = tuple(past[0].shape[:-2]) + (window_length,) + tuple(past[0].shape[-1:])
|
||||
|
||||
zeros_key_val_shape = (
|
||||
tuple(past[0].shape[:-2]) + tuple([curr_length - window_length]) + tuple(past[0].shape[-1:])
|
||||
)
|
||||
zeros_key_val_shape = tuple(past[0].shape[:-2]) + (curr_length - window_length,) + tuple(past[0].shape[-1:])
|
||||
|
||||
ones_mask = torch.ones(ones_key_val_shape)
|
||||
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
|
||||
|
||||
Reference in New Issue
Block a user