Replace legacy tensor.Tensor with torch.tensor/torch.empty (#12027)
* Replace legacy torch.Tensor constructor with torch.{tensor, empty}
* Remove torch.Tensor in examples
This commit is contained in:
@@ -72,7 +72,7 @@ class MaskedLinear(nn.Linear):
|
||||
if self.pruning_method in ["topK", "threshold", "sigmoied_threshold", "l0"]:
|
||||
self.mask_scale = mask_scale
|
||||
self.mask_init = mask_init
|
||||
self.mask_scores = nn.Parameter(torch.Tensor(self.weight.size()))
|
||||
self.mask_scores = nn.Parameter(torch.empty(self.weight.size()))
|
||||
self.init_mask()
|
||||
|
||||
def init_mask(self):
|
||||
|
||||
Reference in New Issue
Block a user