[torch] remove deprecated uint8 in favor of bool (#21384)
* uint8 -> bool * fix copies * style * update test modeling commen when checking attention buffers * style * use logical not on random mask instead of subtraction with 1 * remove torch uint8 * quality * remove modified modeling utils * Update based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> --------- Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -97,7 +97,7 @@ class CodeGenAttention(nn.Module):
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"causal_mask",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
|
||||
@@ -145,7 +145,7 @@ class XSoftmax(torch.autograd.Function):
|
||||
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
||||
)
|
||||
output = softmax(g, output, dim)
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
||||
|
||||
|
||||
class DropoutContext(object):
|
||||
|
||||
@@ -136,7 +136,7 @@ class XSoftmax(torch.autograd.Function):
|
||||
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
||||
)
|
||||
output = softmax(g, output, dim)
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
||||
|
||||
|
||||
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
|
||||
|
||||
@@ -115,7 +115,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
@@ -181,7 +181,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
|
||||
@@ -127,7 +127,7 @@ class GPT2Attention(nn.Module):
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
|
||||
@@ -133,7 +133,7 @@ class GPTNeoSelfAttention(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
)
|
||||
|
||||
@@ -187,7 +187,7 @@ class GPTNeoSelfAttention(nn.Module):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
|
||||
@@ -86,7 +86,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
@@ -193,7 +193,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||
key_length = key.size(-2)
|
||||
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
|
||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
|
||||
@@ -180,13 +180,13 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
# -> [bs, seq_len, hidden_size]
|
||||
return tensor
|
||||
|
||||
def _create_casual_mask(self, key_length, query_length):
|
||||
casual_mask = torch.tril(
|
||||
torch.ones((self.max_positions, self.max_positions), dtype=torch.uint8).view(
|
||||
def _create_causal_mask(self, key_length, query_length):
|
||||
causal_mask = torch.tril(
|
||||
torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view(
|
||||
1, 1, self.max_positions, self.max_positions
|
||||
)
|
||||
)
|
||||
return casual_mask[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
return causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||
@@ -194,7 +194,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||
key_length = key.size(-2)
|
||||
|
||||
causal_mask = self._create_casual_mask(key_length, query_length)
|
||||
causal_mask = self._create_causal_mask(key_length, query_length)
|
||||
|
||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||
|
||||
@@ -78,7 +78,7 @@ def convert_megatron_checkpoint(sd_megatron, config):
|
||||
|
||||
pf = "model.language_model.encoder.layers."
|
||||
for i in range(layers):
|
||||
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.uint8))
|
||||
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.bool))
|
||||
causal_mask = causal_mask.view(1, 1, n_positions, n_positions)
|
||||
sd_hf[f"transformer.h.{i}.attn.bias"] = causal_mask
|
||||
sd_hf[f"transformer.h.{i}.attn.masked_bias"] = torch.tensor(-1e4, dtype=torch.bfloat16)
|
||||
|
||||
@@ -90,7 +90,7 @@ class GPTJAttention(nn.Module):
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
@@ -155,7 +155,7 @@ class GPTJAttention(nn.Module):
|
||||
):
|
||||
# compute causal mask from causal mask buffer
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
|
||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||
query = query.to(torch.float32)
|
||||
|
||||
@@ -180,7 +180,7 @@ class ImageGPTAttention(nn.Module):
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
@@ -244,7 +244,7 @@ class ImageGPTAttention(nn.Module):
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||
|
||||
@@ -71,7 +71,7 @@ def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
|
||||
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = filter_value
|
||||
|
||||
@@ -404,12 +404,12 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
|
||||
# bool attention mask with True in locations of global attention
|
||||
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||
if before_sep_token is True:
|
||||
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
|
||||
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool)
|
||||
else:
|
||||
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
|
||||
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * (
|
||||
attention_mask.expand_as(input_ids) < input_ids.shape[-1]
|
||||
).to(torch.uint8)
|
||||
).to(torch.bool)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@@ -666,7 +666,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
# add an extra bucket for padding tokens only
|
||||
num_buckets = num_buckets + 1
|
||||
# assign padding tokens extra bucket
|
||||
buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape)
|
||||
buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape)
|
||||
buckets = torch.where(
|
||||
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
|
||||
)
|
||||
@@ -841,7 +841,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
# attention mask for LSH
|
||||
if attention_mask is not None:
|
||||
# if chunked attention, the attention mask has to correspond to LSH order
|
||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||
attention_mask = attention_mask.to(torch.bool)[:, None, :]
|
||||
if not do_standard_self_attention:
|
||||
# expand attn_mask to fit with key_value_bucket_idx shape
|
||||
attention_mask = attention_mask[:, None, :]
|
||||
@@ -1225,7 +1225,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
||||
):
|
||||
# chunk attention mask and look before and after
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
||||
attention_mask = attention_mask.to(torch.bool)[:, None, :]
|
||||
|
||||
if not do_standard_self_attention:
|
||||
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
||||
@@ -2159,8 +2159,8 @@ class ReformerModel(ReformerPreTrainedModel):
|
||||
else:
|
||||
attention_mask = torch.cat(
|
||||
[
|
||||
torch.ones(input_shape, device=device, dtype=torch.uint8),
|
||||
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8),
|
||||
torch.ones(input_shape, device=device, dtype=torch.bool),
|
||||
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
@@ -566,7 +566,7 @@ class XSoftmax(torch.autograd.Function):
|
||||
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
||||
)
|
||||
output = softmax(g, output, dim)
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
||||
|
||||
|
||||
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
|
||||
|
||||
@@ -927,7 +927,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
mlen = mems[0].size(0) if mems is not None else 0
|
||||
klen = mlen + qlen
|
||||
if self.same_length:
|
||||
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
|
||||
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool)
|
||||
mask_len = klen - self.mem_len
|
||||
if mask_len > 0:
|
||||
mask_shift_len = qlen - mask_len
|
||||
@@ -935,7 +935,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
mask_shift_len = qlen
|
||||
dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
|
||||
else:
|
||||
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
|
||||
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[
|
||||
:, :, None
|
||||
]
|
||||
|
||||
|
||||
@@ -442,8 +442,11 @@ class ModelTesterMixin:
|
||||
# Before we test anything
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
|
||||
max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
|
||||
else:
|
||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_save_load_fast_init_to_base(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -490,10 +493,15 @@ class ModelTesterMixin:
|
||||
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||
|
||||
for key in model_fast_init.state_dict().keys():
|
||||
max_diff = torch.max(
|
||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
||||
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
|
||||
max_diff = torch.max(
|
||||
model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
|
||||
).item()
|
||||
else:
|
||||
max_diff = torch.max(
|
||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||
).item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user