From c51dc4f92755c67a83f3fc8a0bd6b3e64df199e4 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Mon, 27 Feb 2023 11:46:02 +0100 Subject: [PATCH] [torch] remove deprecated uint8 in favor of bool (#21384) * uint8 -> bool * fix copies * style * update test modeling commen when checking attention buffers * style * use logical not on random mask instead of subtraction with 1 * remove torch uint8 * quality * remove modified modeling utils * Update based on review Co-authored-by: sgugger --------- Co-authored-by: sgugger --- .../models/codegen/modeling_codegen.py | 2 +- .../models/deberta/modeling_deberta.py | 2 +- .../models/deberta_v2/modeling_deberta_v2.py | 2 +- .../modeling_decision_transformer.py | 4 ++-- src/transformers/models/gpt2/modeling_gpt2.py | 4 ++-- .../models/gpt_neo/modeling_gpt_neo.py | 4 ++-- .../models/gpt_neox/modeling_gpt_neox.py | 4 ++-- .../modeling_gpt_neox_japanese.py | 10 +++++----- .../gpt_sw3/convert_megatron_to_pytorch.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 4 ++-- .../models/imagegpt/modeling_imagegpt.py | 4 ++-- .../models/jukebox/modeling_jukebox.py | 2 +- .../models/longformer/modeling_longformer.py | 6 +++--- .../models/reformer/modeling_reformer.py | 10 +++++----- .../models/sew_d/modeling_sew_d.py | 2 +- .../models/transfo_xl/modeling_transfo_xl.py | 4 ++-- tests/test_modeling_common.py | 20 +++++++++++++------ 17 files changed, 47 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index b564dcdb68..936e40656f 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -97,7 +97,7 @@ class CodeGenAttention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "causal_mask", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), ) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 2b63b428bb..7c98a2d0d4 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -145,7 +145,7 @@ class XSoftmax(torch.autograd.Function): g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) ) output = softmax(g, output, dim) - return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool))) class DropoutContext(object): diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index ef04a24b2f..cd6e6318e3 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -136,7 +136,7 @@ class XSoftmax(torch.autograd.Function): g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) ) output = softmax(g, output, dim) - return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool))) # Copied from transformers.models.deberta.modeling_deberta.DropoutContext diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 18257236a3..3ca52a250f 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -115,7 +115,7 @@ class DecisionTransformerGPT2Attention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), ) @@ -181,7 +181,7 @@ class DecisionTransformerGPT2Attention(nn.Module): if not self.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index cde13f7bdb..653762794a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -127,7 +127,7 @@ class GPT2Attention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), ) @@ -193,7 +193,7 @@ class GPT2Attention(nn.Module): if not self.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3391c9f116..9f23f3cbef 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -133,7 +133,7 @@ class GPTNeoSelfAttention(nn.Module): super().__init__() max_positions = config.max_position_embeddings - bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view( 1, 1, max_positions, max_positions ) @@ -187,7 +187,7 @@ class GPTNeoSelfAttention(nn.Module): attn_weights = torch.matmul(query, key.transpose(-1, -2)) query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index b624e0a749..35eed6e808 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -86,7 +86,7 @@ class GPTNeoXAttention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), ) @@ -193,7 +193,7 @@ class GPTNeoXAttention(nn.Module): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index cc3e7bd2c9..388d9b3d52 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -180,13 +180,13 @@ class GPTNeoXJapaneseAttention(nn.Module): # -> [bs, seq_len, hidden_size] return tensor - def _create_casual_mask(self, key_length, query_length): - casual_mask = torch.tril( - torch.ones((self.max_positions, self.max_positions), dtype=torch.uint8).view( + def _create_causal_mask(self, key_length, query_length): + causal_mask = torch.tril( + torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view( 1, 1, self.max_positions, self.max_positions ) ) - return casual_mask[:, :, key_length - query_length : key_length, :key_length].bool() + return causal_mask[:, :, key_length - query_length : key_length, :key_length] def _attn(self, query, key, value, attention_mask=None, head_mask=None): # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] @@ -194,7 +194,7 @@ class GPTNeoXJapaneseAttention(nn.Module): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) - causal_mask = self._create_casual_mask(key_length, query_length) + causal_mask = self._create_causal_mask(key_length, query_length) query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) diff --git a/src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py b/src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py index 13160f77c1..5562efa287 100644 --- a/src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py +++ b/src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py @@ -78,7 +78,7 @@ def convert_megatron_checkpoint(sd_megatron, config): pf = "model.language_model.encoder.layers." for i in range(layers): - causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.uint8)) + causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.bool)) causal_mask = causal_mask.view(1, 1, n_positions, n_positions) sd_hf[f"transformer.h.{i}.attn.bias"] = causal_mask sd_hf[f"transformer.h.{i}.attn.masked_bias"] = torch.tensor(-1e4, dtype=torch.bfloat16) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index f9c49db52d..14f6979d78 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -90,7 +90,7 @@ class GPTJAttention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), ) @@ -155,7 +155,7 @@ class GPTJAttention(nn.Module): ): # compute causal mask from causal mask buffer query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] # Keep the attention weights computation in fp32 to avoid overflow issues query = query.to(torch.float32) diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 21305de732..acd50d2be8 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -180,7 +180,7 @@ class ImageGPTAttention(nn.Module): max_positions = config.max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( 1, 1, max_positions, max_positions ), ) @@ -244,7 +244,7 @@ class ImageGPTAttention(nn.Module): if not self.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` diff --git a/src/transformers/models/jukebox/modeling_jukebox.py b/src/transformers/models/jukebox/modeling_jukebox.py index cac9300539..f7be47c005 100755 --- a/src/transformers/models/jukebox/modeling_jukebox.py +++ b/src/transformers/models/jukebox/modeling_jukebox.py @@ -71,7 +71,7 @@ def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): sorted_indices_to_remove[..., 0] = 0 # indices_to_remove = sorted_indices[sorted_indices_to_remove] - indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_( + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_( dim=-1, index=sorted_indices, src=sorted_indices_to_remove ) logits[indices_to_remove] = filter_value diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 6ad6bdad67..6a16b72c79 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -404,12 +404,12 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru # bool attention mask with True in locations of global attention attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) if before_sep_token is True: - attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8) + attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool) else: # last token is separation token and should not be counted and in the middle are two separation tokens - attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * ( + attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * ( attention_mask.expand_as(input_ids) < input_ids.shape[-1] - ).to(torch.uint8) + ).to(torch.bool) return attention_mask diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index ff90b9ac9a..fb3100f88b 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -666,7 +666,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # add an extra bucket for padding tokens only num_buckets = num_buckets + 1 # assign padding tokens extra bucket - buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape) + buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape) buckets = torch.where( buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device) ) @@ -841,7 +841,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # attention mask for LSH if attention_mask is not None: # if chunked attention, the attention mask has to correspond to LSH order - attention_mask = attention_mask.to(torch.uint8)[:, None, :] + attention_mask = attention_mask.to(torch.bool)[:, None, :] if not do_standard_self_attention: # expand attn_mask to fit with key_value_bucket_idx shape attention_mask = attention_mask[:, None, :] @@ -1225,7 +1225,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): ): # chunk attention mask and look before and after if attention_mask is not None: - attention_mask = attention_mask.to(torch.uint8)[:, None, :] + attention_mask = attention_mask.to(torch.bool)[:, None, :] if not do_standard_self_attention: attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) @@ -2159,8 +2159,8 @@ class ReformerModel(ReformerPreTrainedModel): else: attention_mask = torch.cat( [ - torch.ones(input_shape, device=device, dtype=torch.uint8), - torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8), + torch.ones(input_shape, device=device, dtype=torch.bool), + torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool), ], dim=-1, ) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 9ddfc7821f..5513fec19d 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -566,7 +566,7 @@ class XSoftmax(torch.autograd.Function): g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) ) output = softmax(g, output, dim) - return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) + return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool))) # Copied from transformers.models.deberta.modeling_deberta.DropoutContext diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 094b2d33f6..e97091c35c 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ -927,7 +927,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): mlen = mems[0].size(0) if mems is not None else 0 klen = mlen + qlen if self.same_length: - all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8) + all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool) mask_len = klen - self.mem_len if mask_len > 0: mask_shift_len = qlen - mask_len @@ -935,7 +935,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): mask_shift_len = qlen dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1 else: - dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[ + dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[ :, :, None ] diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index eddf503334..ed614abbc5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -442,8 +442,11 @@ class ModelTesterMixin: # Before we test anything for key in model_fast_init.state_dict().keys(): - max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() - self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical") + if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor): + max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item() + else: + max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") def test_save_load_fast_init_to_base(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -490,10 +493,15 @@ class ModelTesterMixin: model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False) for key in model_fast_init.state_dict().keys(): - max_diff = torch.max( - torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) - ).item() - self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical") + if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor): + max_diff = torch.max( + model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key] + ).item() + else: + max_diff = torch.max( + torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]) + ).item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()