cuda again
This commit is contained in:
@@ -218,7 +218,7 @@ def run_model():
|
|||||||
print_2d_tensor(head_importance)
|
print_2d_tensor(head_importance)
|
||||||
logger.info("Head ranked by importance scores")
|
logger.info("Head ranked by importance scores")
|
||||||
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
|
||||||
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel())
|
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
|
||||||
head_ranks = head_ranks.view_as(head_importance)
|
head_ranks = head_ranks.view_as(head_importance)
|
||||||
print_2d_tensor(head_ranks)
|
print_2d_tensor(head_ranks)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user