From 05deb52dc170226fcad899c240ee9568986d8d5a Mon Sep 17 00:00:00 2001 From: Michael Pang Date: Tue, 7 Apr 2020 15:19:18 -0500 Subject: [PATCH] Optimize causal mask using torch.where (#2715) * Optimize causal mask using torch.where Instead of multiplying by 1.0 float mask, use torch.where with a bool mask for increased performance. * Maintain compatibility with torch 1.0.0 - thanks for PR feedback * Fix typo * reformat line for CI --- src/transformers/modeling_gpt2.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 94fb3ac1db..c89fc46113 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -104,7 +104,10 @@ class Attention(nn.Module): n_state = nx # in Attention: n_state=768 (nx=n_embd) # [switch nx => n_state from Block to Attention to keep identical to TF implem] assert n_state % config.n_head == 0 - self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) + self.register_buffer( + "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx) + ) + self.register_buffer("masked_bias", torch.tensor(-1e4)) self.n_head = config.n_head self.split_size = n_state self.scale = scale @@ -142,8 +145,8 @@ class Attention(nn.Module): if self.scale: w = w / math.sqrt(v.size(-1)) nd, ns = w.size(-2), w.size(-1) - b = self.bias[:, :, ns - nd : ns, :ns] - w = w * b - 1e4 * (1 - b) + mask = self.bias[:, :, ns - nd : ns, :ns] + w = torch.where(mask, w, self.masked_bias) if attention_mask is not None: # Apply the attention mask