Correct device assignment in run_generation
This commit is contained in:
@@ -125,7 +125,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
|
||||
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
|
||||
|
||||
if xlm_lang is not None:
|
||||
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1)
|
||||
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
|
||||
|
||||
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
|
||||
next_token_logits = outputs[0][0, -1, :] / temperature
|
||||
|
||||
Reference in New Issue
Block a user