From f3e0218fbb6bcc40b40f10089dae8876654edb23 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Sat, 5 Oct 2019 21:05:16 -0400 Subject: [PATCH] Correct device assignment in run_generation --- examples/run_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_generation.py b/examples/run_generation.py index 33a0ae1816..de2f6b8869 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -125,7 +125,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping} if xlm_lang is not None: - inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1) + inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1) outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) next_token_logits = outputs[0][0, -1, :] / temperature