Optimize causal mask using torch.where (#2715)
* Optimize causal mask using torch.where: instead of multiplying by a 1.0 float mask, use torch.where with a bool mask for increased performance.
* Maintain compatibility with torch 1.0.0 - thanks for PR feedback
* Fix typo
* Reformat line for CI
This commit is contained in:
@@ -104,7 +104,10 @@ class Attention(nn.Module):
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
assert n_state % config.n_head == 0
|
||||
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.register_buffer(
|
||||
"bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
@@ -142,8 +145,8 @@ class Attention(nn.Module):
|
||||
if self.scale:
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
b = self.bias[:, :, ns - nd : ns, :ns]
|
||||
w = w * b - 1e4 * (1 - b)
|
||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||
w = torch.where(mask, w, self.masked_bias)
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
|
||||
Reference in New Issue
Block a user