From 764ab0d46aecbb82d6d16c847d0df88397c7d780 Mon Sep 17 00:00:00 2001 From: cyyever Date: Wed, 2 Apr 2025 21:15:23 +0800 Subject: [PATCH] Merge tensor operations with device transfer operations (#37097) * Merge operations with to Signed-off-by: cyy * Use dtype Signed-off-by: cyy --------- Signed-off-by: cyy --- src/transformers/generation/beam_search.py | 6 +-- .../generation/candidate_generator.py | 4 +- src/transformers/generation/logits_process.py | 2 +- src/transformers/generation/utils.py | 2 +- src/transformers/modeling_rope_utils.py | 8 ++-- src/transformers/models/aria/modeling_aria.py | 4 +- .../models/bamba/modeling_bamba.py | 4 +- src/transformers/models/bark/modeling_bark.py | 8 ++-- src/transformers/models/blip/modeling_blip.py | 8 ++-- .../models/blip_2/modeling_blip_2.py | 2 +- src/transformers/models/bros/modeling_bros.py | 6 +-- .../models/chameleon/modeling_chameleon.py | 8 +++- .../models/deberta_v2/modeling_deberta_v2.py | 2 +- .../modeling_decision_transformer.py | 2 +- .../deepseek_v3/modeling_deepseek_v3.py | 4 +- .../image_processing_deformable_detr_fast.py | 2 +- .../open_llama/modeling_open_llama.py | 10 ++++- .../models/detr/image_processing_detr_fast.py | 2 +- .../models/diffllama/modeling_diffllama.py | 4 +- src/transformers/models/emu3/modeling_emu3.py | 4 +- .../models/falcon/modeling_falcon.py | 4 +- .../models/gemma/modeling_gemma.py | 4 +- .../models/gemma2/modeling_gemma2.py | 4 +- .../models/gemma3/modeling_gemma3.py | 4 +- src/transformers/models/glm/modeling_glm.py | 4 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- .../models/gpt_neo/modeling_gpt_neo.py | 2 +- .../models/gpt_neox/modeling_gpt_neox.py | 4 +- .../modeling_gpt_neox_japanese.py | 4 +- .../models/granite/modeling_granite.py | 4 +- .../models/granitemoe/modeling_granitemoe.py | 4 +- .../modeling_granitemoeshared.py | 4 +- .../models/helium/modeling_helium.py | 4 +- .../models/ibert/quant_modules.py | 2 +- .../models/idefics/modeling_idefics.py | 5 ++- .../models/imagegpt/modeling_imagegpt.py | 4 +- .../models/jetmoe/modeling_jetmoe.py | 4 +- .../models/llama/modeling_llama.py | 4 +- src/transformers/models/mimi/modeling_mimi.py | 4 +- .../models/mistral/modeling_mistral.py | 4 +- .../models/mixtral/modeling_mixtral.py | 4 +- .../models/modernbert/modeling_modernbert.py | 4 +- .../models/moonshine/modeling_moonshine.py | 4 +- .../models/moshi/modeling_moshi.py | 4 +- .../models/nemotron/modeling_nemotron.py | 4 +- src/transformers/models/olmo/modeling_olmo.py | 4 +- .../models/olmo2/modeling_olmo2.py | 4 +- .../models/olmoe/modeling_olmoe.py | 4 +- .../omdet_turbo/modeling_omdet_turbo.py | 6 ++- .../models/persimmon/modeling_persimmon.py | 4 +- src/transformers/models/phi/modeling_phi.py | 4 +- .../models/qwen2/modeling_qwen2.py | 4 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 4 +- .../models/qwen3/modeling_qwen3.py | 4 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 4 +- .../seamless_m4t/modeling_seamless_m4t.py | 43 ++++++++++--------- .../modeling_seamless_m4t_v2.py | 22 +++++----- .../models/seggpt/image_processing_seggpt.py | 2 +- .../models/sew_d/modeling_sew_d.py | 2 +- .../models/speecht5/modeling_speecht5.py | 2 +- .../models/stablelm/modeling_stablelm.py | 4 +- .../models/starcoder2/modeling_starcoder2.py | 4 +- .../models/tapas/modeling_tapas.py | 4 +- .../models/zamba2/modeling_zamba2.py | 4 +- src/transformers/trainer.py | 2 +- src/transformers/trainer_pt_utils.py | 2 +- src/transformers/utils/import_utils.py | 2 +- 67 files changed, 209 insertions(+), 113 deletions(-) diff --git a/src/transformers/generation/beam_search.py b/src/transformers/generation/beam_search.py index c57999ba21..3938deb482 100644 --- a/src/transformers/generation/beam_search.py +++ b/src/transformers/generation/beam_search.py @@ -724,7 +724,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): advance_state.reset(pre_seq.tolist()) if not advance_state.completed: - advance_tokens = torch.LongTensor(advance_state.advance()).to(device) + advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device) for advance_token in advance_tokens: # since adding each `advance_token` leads to a different hypothesis, create new state instance. new_state = advance_state.copy(stateful=True) @@ -775,14 +775,14 @@ class ConstrainedBeamSearchScorer(BeamScorer): track_new["new_states"].append(advance_state) if len(track_new["new_indices"]) > 0: - new_indices = torch.tensor(track_new["new_indices"]).to(device) + new_indices = torch.tensor(track_new["new_indices"], device=device) new_tokens = torch.stack(track_new["new_tokens"]).to(device) new_scores = torch.stack(track_new["new_scores"]).to(device) all_states = topk_contraint_states + track_new["new_states"] all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1) all_scores = torch.cat((sent_beam_scores, new_scores), -1) - all_banks = torch.tensor([one.get_bank() for one in all_states]).to(device) + all_banks = torch.tensor([one.get_bank() for one in all_states], device=device) zipped = all_banks * 100 + all_scores indices = zipped.sort(descending=True).indices diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index c2a904238a..fe57f532e6 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -719,7 +719,9 @@ class AssistantToTargetTranslator: """ target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size) - target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device) + target_logits: torch.FloatTensor = torch.full( + target_shape, self.FILTER_VALUE, device=self._assistant_model_device + ) # Mask for valid indices assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID # Exclude invalid indices diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index df51419970..16c04478f0 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -1157,7 +1157,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): # Precompute the bias tensors to be applied. Sequences of length 1 are kept separately, as they can be applied # with simpler logic. - self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float).to(scores.device) + self.length_1_bias = torch.zeros((vocabulary_size,), dtype=torch.float, device=scores.device) for sequence_ids, bias in self.sequence_bias.items(): if len(sequence_ids) == 1: self.length_1_bias[sequence_ids[-1]] = bias diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index bfb404be95..232cceeedf 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2599,7 +2599,7 @@ class GenerationMixin: if synced_gpus: # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device) # send 0.0 if we finished, 1.0 otherwise dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) # did all peers finish? the reduced sum will be 0.0 then diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index 7f7fbb3df4..8cf2c3bd3e 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -64,7 +64,7 @@ def _compute_default_rope_parameters( attention_factor = 1.0 # Unused in this type of RoPE # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) return inv_freq, attention_factor @@ -156,7 +156,7 @@ def _compute_dynamic_ntk_parameters( # Compute the inverse frequencies base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) return inv_freq, attention_factor @@ -241,14 +241,14 @@ def _compute_yarn_parameters( # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs # to expand the possible context length. In other words, interpolation = apply scaling factor. - pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim) + pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (factor * pos_freqs) low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device) + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) inv_freq = ( inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + inv_freq_extrapolation * inv_freq_extrapolation_factor diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 5c3b50caef..35c1730f1e 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -783,7 +783,9 @@ class AriaTextRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 953e1242c5..5af06951c6 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -173,7 +173,9 @@ class BambaRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index ccfae95058..e5b7c0a49a 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -897,7 +897,7 @@ class BarkSemanticModel(BarkCausalModel): # pass input_ids in order to stay consistent with the transformers generate method even though it is not used # (except to get the input seq_len - that's why we keep the first 257 tokens) semantic_output = super().generate( - torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device), + torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int, device=self.device), input_embeds=input_embeds, logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor], generation_config=semantic_generation_config, @@ -989,8 +989,8 @@ class BarkCoarseModel(BarkCausalModel): else: # shape: (batch_size, 0) - x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device) - x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device) + x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device) + x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int, device=self.device) return x_semantic_history, x_coarse_history @@ -1097,7 +1097,7 @@ class BarkCoarseModel(BarkCausalModel): input_coarse = torch.hstack( [ input_coarse, - torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size).to(self.device), + torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size, device=self.device), x_coarse[:, -max_coarse_history:], ] ) diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 2a7430b937..dd9b57973c 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -1198,7 +1198,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): image_embeds = vision_outputs[0] - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) if isinstance(input_ids, list): input_ids = torch.LongTensor(input_ids) @@ -1424,7 +1424,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin): image_embeds = vision_outputs[0] - image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device) + image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device) if isinstance(input_ids, list): input_ids = torch.LongTensor(input_ids) @@ -1439,7 +1439,9 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin): question_embeds = question_outputs[0] - question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device) + question_attention_mask = torch.ones( + question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device + ) bos_ids = torch.full( (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index de15c0d1ed..a02868359e 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -2498,7 +2498,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel): if use_image_text_matching_head: query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) - query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(query_tokens.device) + query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=query_tokens.device) attention_mask = torch.cat([query_attention_mask, attention_mask], dim=1) query_embeds = self.embeddings( diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 0e1e86c0b3..ee278631f2 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -1158,7 +1158,7 @@ class BrosSpadeEEForTokenClassification(BrosPreTrainedModel): subsequent_token_logits = subsequent_token_logits.masked_fill( invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min ) - self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool) subsequent_token_logits = subsequent_token_logits.masked_fill( self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min ) @@ -1287,13 +1287,13 @@ class BrosSpadeELForTokenClassification(BrosPreTrainedModel): batch_size, max_seq_length = attention_mask.shape device = attention_mask.device - self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device).bool() + self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool) mask = bbox_first_token_mask.view(-1) bbox_first_token_mask = torch.cat( [ ~bbox_first_token_mask, - torch.zeros([batch_size, 1], dtype=torch.bool).to(device), + torch.zeros([batch_size, 1], dtype=torch.bool, device=device), ], axis=1, ) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index aff97487bb..7ccc660aac 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -95,7 +95,10 @@ class ChameleonRotaryEmbedding(nn.Module): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # For BC we register cos and sin cached self.max_seq_len_cached = max_position_embeddings @@ -138,7 +141,8 @@ class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding): (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index b02628ed69..e9a4bbf8ae 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -300,7 +300,7 @@ class DisentangledSelfAttention(nn.Module): raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") att_span = self.pos_ebd_size - relative_pos = relative_pos.long().to(query_layer.device) + relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long) rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) if self.share_att_key: diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 08c0f918c4..faf24729d6 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -233,7 +233,7 @@ class DecisionTransformerGPT2Attention(nn.Module): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 67564fbca4..5a09e85779 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -113,7 +113,9 @@ class DeepseekV3RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py index 5d400f3d13..27c9aaa371 100644 --- a/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +++ b/src/transformers/models/deformable_detr/image_processing_deformable_detr_fast.py @@ -237,7 +237,7 @@ def prepare_coco_panoptic_annotation( new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device) if "segments_info" in target: - masks = read_image(annotation_path).permute(1, 2, 0).to(torch.int32).to(image.device) + masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device) masks = rgb_to_id(masks) ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 2a74517d9f..98bc7fb70a 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -73,7 +73,10 @@ class OpenLlamaRotaryEmbedding(nn.Module): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -135,7 +138,10 @@ class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) diff --git a/src/transformers/models/detr/image_processing_detr_fast.py b/src/transformers/models/detr/image_processing_detr_fast.py index 28a8929140..16bef79e59 100644 --- a/src/transformers/models/detr/image_processing_detr_fast.py +++ b/src/transformers/models/detr/image_processing_detr_fast.py @@ -254,7 +254,7 @@ def prepare_coco_panoptic_annotation( new_target["orig_size"] = torch.as_tensor([image_height, image_width], dtype=torch.int64, device=image.device) if "segments_info" in target: - masks = read_image(annotation_path).permute(1, 2, 0).to(torch.int32).to(image.device) + masks = read_image(annotation_path).permute(1, 2, 0).to(dtype=torch.int32, device=image.device) masks = rgb_to_id(masks) ids = torch.as_tensor([segment_info["id"] for segment_info in target["segments_info"]], device=image.device) diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 933cf15963..b25c19384e 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -675,7 +675,9 @@ class DiffLlamaRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 3322ce28f9..e013e86632 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1256,7 +1256,9 @@ class Emu3RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e6cb36353d..50638858de 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -158,7 +158,9 @@ class FalconRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index dfa7aabfcf..31df29e6a0 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -149,7 +149,9 @@ class GemmaRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 75e318009c..556849d0bc 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -396,7 +396,9 @@ class Gemma2RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index e17009f402..1078a01d93 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -191,7 +191,9 @@ class Gemma3RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 8bd9031127..7d77f3f2f1 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -313,7 +313,9 @@ class GlmRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 0c045094e7..1af1366925 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -243,7 +243,7 @@ class GPT2Attention(nn.Module): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 12e45d918d..9cc18c0ea9 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -219,7 +219,7 @@ class GPTNeoSelfAttention(nn.Module): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: # no matter the length, we just slice it diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 11dfbab89d..220e0b6e72 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -337,7 +337,9 @@ class GPTNeoXRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index aab7183ccf..26ee5e392a 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -281,7 +281,9 @@ class GPTNeoXJapaneseRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 6b64e18aa7..74bb0d054f 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -367,7 +367,9 @@ class GraniteRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index c05a105040..39441473cb 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -211,7 +211,9 @@ class GraniteMoeRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index fd2cdf2843..29d1b598f4 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -801,7 +801,9 @@ class GraniteMoeSharedRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 5e66ce7298..6b786f656c 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -132,7 +132,9 @@ class HeliumRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/ibert/quant_modules.py b/src/transformers/models/ibert/quant_modules.py index d490d555a7..949702a5af 100644 --- a/src/transformers/models/ibert/quant_modules.py +++ b/src/transformers/models/ibert/quant_modules.py @@ -651,7 +651,7 @@ class SymmetricQuantFunction(Function): Returns: `torch.Tensor`: Symmetric-quantized value of *input*. """ - zero_point = torch.tensor(0.0).to(scale.device) + zero_point = torch.tensor(0.0, device=scale.device) n = 2 ** (k - 1) - 1 new_quant_x = linear_quantize(x, scale, zero_point, inplace=False) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 4255a1ebb3..9c21213b0b 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -415,7 +415,10 @@ class IdeficsEmbedding(torch.nn.Module): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 2e1f70a549..f75e24852b 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -247,7 +247,7 @@ class ImageGPTAttention(nn.Module): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: @@ -297,7 +297,7 @@ class ImageGPTAttention(nn.Module): mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype, device=attn_weights.device) attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index b87ca1376e..ae2bb44bf9 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -443,7 +443,9 @@ class JetMoeRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f7d8c714d8..d64dd9b7b6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -138,7 +138,9 @@ class LlamaRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 4746755a4f..b76cebc188 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -412,7 +412,9 @@ class MimiRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 65044b4dbe..b22e92f6f6 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -321,7 +321,9 @@ class MistralRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 7cfc266fd3..973fff5f17 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -443,7 +443,9 @@ class MixtralRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 28ba79b830..4f087ec382 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -295,7 +295,9 @@ class ModernBertRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 80742db812..040358ad46 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -366,7 +366,9 @@ class MoonshineRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index cc352e76fb..fcd4aea931 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -356,7 +356,9 @@ class MoshiRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 6a362804a2..d9cd4df990 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -143,7 +143,9 @@ class NemotronRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 79b8c28ca2..749b729b12 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -332,7 +332,9 @@ class OlmoRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index cc83dd4d17..35f0376f2d 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -333,7 +333,9 @@ class Olmo2RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index e6246ac600..9589e4dd7a 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -206,7 +206,9 @@ class OlmoeRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 570c8cc3a3..e95e4a522b 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -1315,7 +1315,7 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel): # [batch_size, height*width, channels] new_vision_features = torch.cat(new_vision_features, 1) - new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64).to(vision_features[0].device) + new_vision_shapes = torch.tensor(new_vision_shapes_list, dtype=torch.int64, device=vision_features[0].device) level_start_index = torch.cat((new_vision_shapes.new_zeros((1,)), new_vision_shapes.prod(1).cumsum(0)[:-1])) return new_vision_features, new_vision_shapes, new_vision_shapes_list, level_start_index @@ -1330,7 +1330,9 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel): ) predicted_class_features = self.encoder_vision_features( torch.where( - valid_mask, vision_features, torch.tensor(0.0, dtype=vision_features.dtype).to(vision_features.device) + valid_mask, + vision_features, + torch.tensor(0.0, dtype=vision_features.dtype, device=vision_features.device), ) ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 1e3a784026..5c9c62cc46 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -113,7 +113,9 @@ class PersimmonRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 612bd70407..539be9216d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -328,7 +328,9 @@ class PhiRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 232598b231..5d5464dcae 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -334,7 +334,9 @@ class Qwen2RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 8e1e8de79e..5e5be0a8f1 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -216,7 +216,9 @@ class Qwen2MoeRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index b235c8acde..0559ec789b 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -361,7 +361,9 @@ class Qwen3RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 3897cd44bd..aa1418bd1c 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -456,7 +456,9 @@ class Qwen3MoeRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 372010a428..b4ebc88b0e 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -2873,7 +2873,7 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): ) # tgt_lang gets priority over decoder input ids text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) else: raise ValueError( """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps @@ -3144,7 +3144,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin): ) # tgt_lang gets priority over decoder input ids text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) else: raise ValueError( """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps @@ -3420,7 +3420,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) kwargs_text["decoder_input_ids"] = text_decoder_input_ids @@ -3441,7 +3441,8 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) idx_most_probable_sequences_per_batch = ( - idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences ) sequences = sequences[idx_most_probable_sequences_per_batch] @@ -3462,8 +3463,8 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): # Compute t2u decoder_input_ids t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) - t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( - self.device + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device ) kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids # second generation @@ -3480,9 +3481,9 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): ) vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) - vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) - spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) @@ -3748,7 +3749,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): text_decoder_input_ids = kwargs_text.get("decoder_input_ids") # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) kwargs_text["decoder_input_ids"] = text_decoder_input_ids @@ -3779,7 +3780,8 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) idx_most_probable_sequences_per_batch = ( - idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences ) sequences = sequences[idx_most_probable_sequences_per_batch] @@ -3800,8 +3802,8 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): # Compute t2u decoder_input_ids t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) - t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( - self.device + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device ) kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids @@ -3819,9 +3821,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): ) vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) - vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) - spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) @@ -4171,7 +4173,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): if tgt_lang is not None: # tgt_lang gets priority over decoder input ids text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) kwargs_text["decoder_input_ids"] = text_decoder_input_ids @@ -4221,7 +4223,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): idx_most_probable_sequences_per_batch = text_generation_output.sequences_scores.view(batch_size, -1) idx_most_probable_sequences_per_batch = idx_most_probable_sequences_per_batch.argmax(-1) idx_most_probable_sequences_per_batch = ( - idx_most_probable_sequences_per_batch + torch.arange(batch_size).to(self.device) * num_return_sequences + idx_most_probable_sequences_per_batch + + torch.arange(batch_size, device=self.device) * num_return_sequences ) sequences = sequences[idx_most_probable_sequences_per_batch] @@ -4242,8 +4245,8 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): # Compute t2u decoder_input_ids t2u_decoder_input_ids = kwargs_speech.get("decoder_input_ids") t2u_tgt_lang_id = self.generation_config.t2u_lang_code_to_id.get(tgt_lang) - t2u_decoder_input_ids = torch.tensor([[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size).to( - self.device + t2u_decoder_input_ids = torch.tensor( + [[self.config.t2u_eos_token_id, t2u_tgt_lang_id]] * batch_size, device=self.device ) kwargs_speech["decoder_input_ids"] = t2u_decoder_input_ids @@ -4261,9 +4264,9 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): ) vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) - vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) - spkr_id = torch.tensor([[spkr_id]] * len(unit_ids)).to(self.device) + spkr_id = torch.tensor([[spkr_id]] * len(unit_ids), device=self.device) waveform, waveform_lengths = self.vocoder(input_ids=unit_ids, spkr_id=spkr_id, lang_id=vocoder_tgt_lang_id) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index ae191b311e..7a36633db4 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -3153,7 +3153,7 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): ) # tgt_lang gets priority over decoder input ids text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) else: raise ValueError( """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps @@ -3434,7 +3434,7 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin ) # tgt_lang gets priority over decoder input ids text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) else: raise ValueError( """This model generation config doesn't have a `text_decoder_lang_to_code_id` key which maps @@ -3720,7 +3720,7 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) kwargs_text["decoder_input_ids"] = text_decoder_input_ids @@ -3810,9 +3810,9 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin ) vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) - vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) - speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device) + speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device) waveform, waveform_lengths = self.vocoder( input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id @@ -4090,7 +4090,7 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix text_decoder_input_ids = kwargs_text.get("decoder_input_ids") # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids. text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) kwargs_text["decoder_input_ids"] = text_decoder_input_ids @@ -4190,9 +4190,9 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix ) vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) - vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) - speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device) + speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device) waveform, waveform_lengths = self.vocoder( input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id @@ -4559,7 +4559,7 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin): if tgt_lang is not None: # tgt_lang gets priority over decoder input ids text_tgt_lang_id = self.generation_config.text_decoder_lang_to_code_id.get(tgt_lang) - text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size).to(self.device) + text_decoder_input_ids = torch.tensor([[text_tgt_lang_id]] * batch_size, device=self.device) kwargs_text["decoder_input_ids"] = text_decoder_input_ids @@ -4679,9 +4679,9 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin): ) vocoder_tgt_lang_id = self.generation_config.vocoder_lang_code_to_id.get(tgt_lang) - vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids)).to(self.device) + vocoder_tgt_lang_id = torch.tensor([[vocoder_tgt_lang_id]] * len(unit_ids), device=self.device) - speaker_id = torch.tensor([[speaker_id]] * len(unit_ids)).to(self.device) + speaker_id = torch.tensor([[speaker_id]] * len(unit_ids), device=self.device) waveform, waveform_lengths = self.vocoder( input_ids=unit_ids, speaker_id=speaker_id, lang_id=vocoder_tgt_lang_id diff --git a/src/transformers/models/seggpt/image_processing_seggpt.py b/src/transformers/models/seggpt/image_processing_seggpt.py index bcc1ad32ef..26c7c1f47a 100644 --- a/src/transformers/models/seggpt/image_processing_seggpt.py +++ b/src/transformers/models/seggpt/image_processing_seggpt.py @@ -586,7 +586,7 @@ class SegGptImageProcessor(BaseImageProcessor): palette_tensor = None palette = self.get_palette(num_labels) if num_labels is not None else None if palette is not None: - palette_tensor = torch.tensor(palette).float().to(masks.device) + palette_tensor = torch.tensor(palette).to(device=masks.device, dtype=torch.float) _, num_channels, _, _ = masks.shape palette_tensor = palette_tensor.view(1, 1, num_labels + 1, num_channels) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index aab4850a12..9f49a46a05 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -820,7 +820,7 @@ class DisentangledSelfAttention(nn.Module): raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") att_span = self.pos_ebd_size - relative_pos = relative_pos.long().to(query_layer.device) + relative_pos = relative_pos.to(device=query_layer.device, dtype=torch.long) rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) if self.share_att_key: diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 617fb9b592..d85e52924a 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -431,7 +431,7 @@ class SpeechT5RelativePositionalEncoding(torch.nn.Module): def forward(self, hidden_states): seq_len = hidden_states.shape[1] - pos_seq = torch.arange(0, seq_len).long().to(hidden_states.device) + pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long) pos_seq = pos_seq[:, None] - pos_seq[None, :] pos_seq[pos_seq < -self.max_length] = -self.max_length diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index ab14bdb8e6..8d2b3f96ec 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -118,7 +118,9 @@ class StableLmRotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 364f9a3d03..6fbb652c99 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -325,7 +325,9 @@ class Starcoder2RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 5a2450b9a8..95b097013e 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -1243,8 +1243,8 @@ class TapasForQuestionAnswering(TapasPreTrainedModel): if table_mask is None: table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids)) # torch.FloatTensor[batch_size, seq_length] - input_mask_float = attention_mask.float().to(device) - table_mask_float = table_mask.float().to(device) + input_mask_float = attention_mask.to(device=device, dtype=torch.float) + table_mask_float = table_mask.to(device=device, dtype=torch.float) # Mask for cells that exist in the table (i.e. that are not padding). cell_mask, _ = reduce_mean(input_mask_float, cell_index) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 05502d0e4e..2c60cc2276 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -270,7 +270,9 @@ class Zamba2RotaryEmbedding(nn.Module): device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()).transpose(1, 2) + freqs = ( + inv_freq_expanded.to(device=x.device, dtype=torch.float) @ position_ids_expanded.float() + ).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a00eec8b82..5b238e610e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2454,7 +2454,7 @@ class Trainer: self.state.init_training_references(self, max_steps, num_train_epochs, trial) # tr_loss is a tensor to avoid synchronization of TPUs through .item() - tr_loss = torch.tensor(0.0).to(args.device) + tr_loss = torch.tensor(0.0, device=args.device) # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses self._total_loss_scalar = 0.0 self._globalstep_last_logged = self.state.global_step diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 6cd9815afb..30474daea6 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -225,7 +225,7 @@ def distributed_broadcast_scalars( device: Optional[torch.device] = torch.device("cuda"), ) -> torch.Tensor: try: - tensorized_scalar = torch.tensor(scalars).to(device) + tensorized_scalar = torch.tensor(scalars, device=device) output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())] dist.all_gather(output_tensors, tensorized_scalar) concat = torch.cat(output_tensors, dim=0) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1ac109c477..5a618f901c 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -596,7 +596,7 @@ def is_torch_bf16_available_on_device(device): return True try: - x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device) + x = torch.zeros(2, 2, dtype=torch.bfloat16, device=device) _ = x @ x except: # noqa: E722 # TODO: more precise exception matching, if possible.