Fix BART tests on GPU (#4298)

This commit is contained in:
Julien Chaumond
2020-05-12 09:11:50 -04:00
committed by GitHub
parent e4512aab3b
commit 4bf5042240
2 changed files with 6 additions and 2 deletions

View File

@@ -886,7 +886,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
if new_num_tokens <= old_num_tokens:
new_bias = self.final_logits_bias[:, :new_num_tokens]
else:
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens))
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)