Replace legacy tensor.Tensor with torch.tensor/torch.empty (#12027)

* Replace legacy torch.Tensor constructor with torch.{tensor, empty}

* Remove torch.Tensor in examples
This commit is contained in:
Mario Šaško
2021-06-08 05:58:38 -07:00
committed by GitHub
parent e33085d648
commit f5eec0d8e9
13 changed files with 26 additions and 22 deletions

View File

@@ -1426,7 +1426,7 @@ class AnchorGenerator(nn.Module):
h = aspect_ratio * w
x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0
anchors.append([x0, y0, x1, y1])
return nn.Parameter(torch.Tensor(anchors))
return nn.Parameter(torch.tensor(anchors))
def forward(self, features):
"""

View File

@@ -532,7 +532,7 @@ def load_frcnn_pkl_from_url(url):
for k, v in model.items():
new[k] = torch.from_numpy(v)
if "running_var" in k:
zero = torch.Tensor([0])
zero = torch.tensor([0])
k2 = k.replace("running_var", "num_batches_tracked")
new[k2] = zero
return new

View File

@@ -72,7 +72,7 @@ class MaskedLinear(nn.Linear):
if self.pruning_method in ["topK", "threshold", "sigmoied_threshold", "l0"]:
self.mask_scale = mask_scale
self.mask_init = mask_init
self.mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
self.mask_scores = nn.Parameter(torch.empty(self.weight.size()))
self.init_mask()
def init_mask(self):

View File

@@ -223,7 +223,7 @@ class GenerativeQAModule(BaseTransformer):
decoder_start_token_id = generator.config.decoder_start_token_id
decoder_input_ids = (
torch.cat(
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
[torch.tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
dim=1,
)
if target_ids.shape[0] < self.target_lens["train"]

View File

@@ -222,7 +222,7 @@ class GenerativeQAModule(BaseTransformer):
decoder_start_token_id = generator.config.decoder_start_token_id
decoder_input_ids = (
torch.cat(
[torch.Tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
[torch.tensor([[decoder_start_token_id]] * target_ids.shape[0]).to(target_ids), target_ids],
dim=1,
)
if target_ids.shape[0] < self.target_lens["train"]