cuda again

This commit is contained in:
thomwolf
2019-06-19 15:32:10 +02:00
parent 14f0e8e557
commit 909d4f1af2

View File

@@ -218,7 +218,7 @@ def run_model():
print_2d_tensor(head_importance) print_2d_tensor(head_importance)
logger.info("Head ranked by importance scores") logger.info("Head ranked by importance scores")
head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device) head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel()) head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device)
head_ranks = head_ranks.view_as(head_importance) head_ranks = head_ranks.view_as(head_importance)
print_2d_tensor(head_ranks) print_2d_tensor(head_ranks)