From 909d4f1af24d354a6e1a45ca6cf58d30fc0fcc07 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 19 Jun 2019 15:32:10 +0200 Subject: [PATCH] cuda again --- examples/bertology.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/bertology.py b/examples/bertology.py index 6997b9e26d..694328162a 100644 --- a/examples/bertology.py +++ b/examples/bertology.py @@ -218,7 +218,7 @@ def run_model(): print_2d_tensor(head_importance) logger.info("Head ranked by importance scores") head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device) - head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel()) + head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(head_importance.numel(), device=args.device) head_ranks = head_ranks.view_as(head_importance) print_2d_tensor(head_ranks)