From aa3778afc2ad681edcd253837246d5eb8c62291c Mon Sep 17 00:00:00 2001 From: cyyever Date: Fri, 28 Mar 2025 22:26:22 +0800 Subject: [PATCH] Change deprecated PT functions (#37041) Change deprecated functions --- src/transformers/models/clvp/modeling_clvp.py | 4 ++-- .../models/convnextv2/modeling_convnextv2.py | 2 +- .../models/deprecated/jukebox/modeling_jukebox.py | 14 +++++++++----- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 844ca354cd..a8feab0b1a 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1346,8 +1346,8 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin): if hasattr(model_kwargs, "attention_mask"): position_ids = model_kwargs["attention_mask"].long().cumsum(-1) - 1 else: - position_ids = torch.range( - 0, conditioning_embeds.shape[1] - 1, dtype=torch.long, device=conditioning_embeds.device + position_ids = torch.arange( + 0, conditioning_embeds.shape[1], dtype=torch.long, device=conditioning_embeds.device ) position_ids = position_ids.unsqueeze(0).repeat(conditioning_embeds.shape[0], 1) diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index c0490eead2..b779dfbe41 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -100,7 +100,7 @@ class ConvNextV2GRN(nn.Module): def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: # Compute and normalize global spatial feature maps - global_features = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True) + global_features = torch.linalg.norm(hidden_states, ord=2, dim=(1, 2), keepdim=True) norm_features = global_features / (global_features.mean(dim=-1, keepdim=True) + 1e-6) hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states diff --git a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py index b5ab4cea1b..0fd1bad626 100755 --- a/src/transformers/models/deprecated/jukebox/modeling_jukebox.py +++ b/src/transformers/models/deprecated/jukebox/modeling_jukebox.py @@ -429,7 +429,7 @@ class JukeboxBottleneckBlock(nn.Module): entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8)) # entropy ie how diverse used_curr = (_codebook_elem >= self.threshold).sum() usage = torch.sum(usage) - dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape)) + dk = torch.linalg.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape)) return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk} def preprocess(self, hidden_states): @@ -437,11 +437,13 @@ class JukeboxBottleneckBlock(nn.Module): hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) if hidden_states.shape[-1] == self.codebook_width: - prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape)) + prenorm = torch.linalg.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt( + np.prod(hidden_states.shape) + ) elif hidden_states.shape[-1] == 2 * self.codebook_width: x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :] - prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( - torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) + prenorm = (torch.linalg.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( + torch.linalg.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) ) # Normalise @@ -517,7 +519,9 @@ class JukeboxBottleneckBlock(nn.Module): update_metrics = {} # Loss - commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape) + commit_loss = torch.linalg.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod( + hidden_states.shape + ) # Passthrough dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()