[torch] remove deprecated uint8 in favor of bool (#21384)
* uint8 -> bool * fix copies * style * update test modeling commen when checking attention buffers * style * use logical not on random mask instead of subtraction with 1 * remove torch uint8 * quality * remove modified modeling utils * Update based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> --------- Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -97,7 +97,7 @@ class CodeGenAttention(nn.Module):
|
|||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"causal_mask",
|
"causal_mask",
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
||||||
)
|
)
|
||||||
output = softmax(g, output, dim)
|
output = softmax(g, output, dim)
|
||||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
||||||
|
|
||||||
|
|
||||||
class DropoutContext(object):
|
class DropoutContext(object):
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
||||||
)
|
)
|
||||||
output = softmax(g, output, dim)
|
output = softmax(g, output, dim)
|
||||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
|
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -181,7 +181,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
mask_value = torch.finfo(attn_weights.dtype).min
|
||||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ class GPT2Attention(nn.Module):
|
|||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
|
|||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
mask_value = torch.finfo(attn_weights.dtype).min
|
||||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ class GPTNeoSelfAttention(nn.Module):
|
|||||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||||
|
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
mask_value = torch.finfo(attn_weights.dtype).min
|
||||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -193,7 +193,7 @@ class GPTNeoXAttention(nn.Module):
|
|||||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||||
key_length = key.size(-2)
|
key_length = key.size(-2)
|
||||||
|
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
|
|
||||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||||
|
|||||||
@@ -180,13 +180,13 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
|||||||
# -> [bs, seq_len, hidden_size]
|
# -> [bs, seq_len, hidden_size]
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def _create_casual_mask(self, key_length, query_length):
|
def _create_causal_mask(self, key_length, query_length):
|
||||||
casual_mask = torch.tril(
|
causal_mask = torch.tril(
|
||||||
torch.ones((self.max_positions, self.max_positions), dtype=torch.uint8).view(
|
torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view(
|
||||||
1, 1, self.max_positions, self.max_positions
|
1, 1, self.max_positions, self.max_positions
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return casual_mask[:, :, key_length - query_length : key_length, :key_length].bool()
|
return causal_mask[:, :, key_length - query_length : key_length, :key_length]
|
||||||
|
|
||||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||||
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
||||||
@@ -194,7 +194,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
|
|||||||
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
||||||
key_length = key.size(-2)
|
key_length = key.size(-2)
|
||||||
|
|
||||||
causal_mask = self._create_casual_mask(key_length, query_length)
|
causal_mask = self._create_causal_mask(key_length, query_length)
|
||||||
|
|
||||||
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
||||||
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ def convert_megatron_checkpoint(sd_megatron, config):
|
|||||||
|
|
||||||
pf = "model.language_model.encoder.layers."
|
pf = "model.language_model.encoder.layers."
|
||||||
for i in range(layers):
|
for i in range(layers):
|
||||||
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.uint8))
|
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.bool))
|
||||||
causal_mask = causal_mask.view(1, 1, n_positions, n_positions)
|
causal_mask = causal_mask.view(1, 1, n_positions, n_positions)
|
||||||
sd_hf[f"transformer.h.{i}.attn.bias"] = causal_mask
|
sd_hf[f"transformer.h.{i}.attn.bias"] = causal_mask
|
||||||
sd_hf[f"transformer.h.{i}.attn.masked_bias"] = torch.tensor(-1e4, dtype=torch.bfloat16)
|
sd_hf[f"transformer.h.{i}.attn.masked_bias"] = torch.tensor(-1e4, dtype=torch.bfloat16)
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ class GPTJAttention(nn.Module):
|
|||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -155,7 +155,7 @@ class GPTJAttention(nn.Module):
|
|||||||
):
|
):
|
||||||
# compute causal mask from causal mask buffer
|
# compute causal mask from causal mask buffer
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
|
|
||||||
# Keep the attention weights computation in fp32 to avoid overflow issues
|
# Keep the attention weights computation in fp32 to avoid overflow issues
|
||||||
query = query.to(torch.float32)
|
query = query.to(torch.float32)
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ class ImageGPTAttention(nn.Module):
|
|||||||
max_positions = config.max_position_embeddings
|
max_positions = config.max_position_embeddings
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"bias",
|
"bias",
|
||||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
||||||
1, 1, max_positions, max_positions
|
1, 1, max_positions, max_positions
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -244,7 +244,7 @@ class ImageGPTAttention(nn.Module):
|
|||||||
if not self.is_cross_attention:
|
if not self.is_cross_attention:
|
||||||
# if only "normal" attention layer implements causal mask
|
# if only "normal" attention layer implements causal mask
|
||||||
query_length, key_length = query.size(-2), key.size(-2)
|
query_length, key_length = query.size(-2), key.size(-2)
|
||||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
||||||
mask_value = torch.finfo(attn_weights.dtype).min
|
mask_value = torch.finfo(attn_weights.dtype).min
|
||||||
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
||||||
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
|
|||||||
sorted_indices_to_remove[..., 0] = 0
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
|
||||||
# indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
# indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
|
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
|
||||||
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
|
||||||
)
|
)
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
|
|||||||
@@ -404,12 +404,12 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
|
|||||||
# bool attention mask with True in locations of global attention
|
# bool attention mask with True in locations of global attention
|
||||||
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
|
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||||
if before_sep_token is True:
|
if before_sep_token is True:
|
||||||
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
|
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool)
|
||||||
else:
|
else:
|
||||||
# last token is separation token and should not be counted and in the middle are two separation tokens
|
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||||
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
|
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * (
|
||||||
attention_mask.expand_as(input_ids) < input_ids.shape[-1]
|
attention_mask.expand_as(input_ids) < input_ids.shape[-1]
|
||||||
).to(torch.uint8)
|
).to(torch.bool)
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|||||||
@@ -666,7 +666,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
# add an extra bucket for padding tokens only
|
# add an extra bucket for padding tokens only
|
||||||
num_buckets = num_buckets + 1
|
num_buckets = num_buckets + 1
|
||||||
# assign padding tokens extra bucket
|
# assign padding tokens extra bucket
|
||||||
buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape)
|
buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape)
|
||||||
buckets = torch.where(
|
buckets = torch.where(
|
||||||
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
|
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
|
||||||
)
|
)
|
||||||
@@ -841,7 +841,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
# attention mask for LSH
|
# attention mask for LSH
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# if chunked attention, the attention mask has to correspond to LSH order
|
# if chunked attention, the attention mask has to correspond to LSH order
|
||||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
attention_mask = attention_mask.to(torch.bool)[:, None, :]
|
||||||
if not do_standard_self_attention:
|
if not do_standard_self_attention:
|
||||||
# expand attn_mask to fit with key_value_bucket_idx shape
|
# expand attn_mask to fit with key_value_bucket_idx shape
|
||||||
attention_mask = attention_mask[:, None, :]
|
attention_mask = attention_mask[:, None, :]
|
||||||
@@ -1225,7 +1225,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
|
|||||||
):
|
):
|
||||||
# chunk attention mask and look before and after
|
# chunk attention mask and look before and after
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
|
attention_mask = attention_mask.to(torch.bool)[:, None, :]
|
||||||
|
|
||||||
if not do_standard_self_attention:
|
if not do_standard_self_attention:
|
||||||
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
|
||||||
@@ -2159,8 +2159,8 @@ class ReformerModel(ReformerPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
attention_mask = torch.cat(
|
attention_mask = torch.cat(
|
||||||
[
|
[
|
||||||
torch.ones(input_shape, device=device, dtype=torch.uint8),
|
torch.ones(input_shape, device=device, dtype=torch.bool),
|
||||||
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8),
|
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -566,7 +566,7 @@ class XSoftmax(torch.autograd.Function):
|
|||||||
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
|
||||||
)
|
)
|
||||||
output = softmax(g, output, dim)
|
output = softmax(g, output, dim)
|
||||||
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
|
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
|
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
|
||||||
|
|||||||
@@ -927,7 +927,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
mlen = mems[0].size(0) if mems is not None else 0
|
mlen = mems[0].size(0) if mems is not None else 0
|
||||||
klen = mlen + qlen
|
klen = mlen + qlen
|
||||||
if self.same_length:
|
if self.same_length:
|
||||||
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
|
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool)
|
||||||
mask_len = klen - self.mem_len
|
mask_len = klen - self.mem_len
|
||||||
if mask_len > 0:
|
if mask_len > 0:
|
||||||
mask_shift_len = qlen - mask_len
|
mask_shift_len = qlen - mask_len
|
||||||
@@ -935,7 +935,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
mask_shift_len = qlen
|
mask_shift_len = qlen
|
||||||
dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
|
dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
|
||||||
else:
|
else:
|
||||||
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
|
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[
|
||||||
:, :, None
|
:, :, None
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -442,8 +442,11 @@ class ModelTesterMixin:
|
|||||||
# Before we test anything
|
# Before we test anything
|
||||||
|
|
||||||
for key in model_fast_init.state_dict().keys():
|
for key in model_fast_init.state_dict().keys():
|
||||||
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
|
||||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
|
||||||
|
else:
|
||||||
|
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
def test_save_load_fast_init_to_base(self):
|
def test_save_load_fast_init_to_base(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -490,10 +493,15 @@ class ModelTesterMixin:
|
|||||||
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
|
||||||
|
|
||||||
for key in model_fast_init.state_dict().keys():
|
for key in model_fast_init.state_dict().keys():
|
||||||
max_diff = torch.max(
|
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
|
||||||
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
max_diff = torch.max(
|
||||||
).item()
|
model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
|
||||||
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
|
).item()
|
||||||
|
else:
|
||||||
|
max_diff = torch.max(
|
||||||
|
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
|
||||||
|
).item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user