[cleanup] factor out get_head_mask, invert_attn_mask, get_exten… (#3806)

* Delete some copy pasted code
This commit is contained in:
Sam Shleifer
2020-04-16 09:55:25 -04:00
committed by GitHub
parent d22894dfd4
commit dbd041243d
13 changed files with 132 additions and 379 deletions

View File

@@ -392,26 +392,11 @@ class CTRLModel(CTRLPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.n_layer
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])