From 4bf5042240d33286460b83f3dbf9be77500faab3 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Tue, 12 May 2020 09:11:50 -0400 Subject: [PATCH] Fix BART tests on GPU (#4298) --- src/transformers/modeling_bart.py | 2 +- tests/test_modeling_bart.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index aa7c2b23d0..a461a6b478 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -886,7 +886,7 @@ class BartForConditionalGeneration(PretrainedBartModel): if new_num_tokens <= old_num_tokens: new_bias = self.final_logits_bias[:, :new_num_tokens] else: - extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens)) + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) self.register_buffer("final_logits_bias", new_bias) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index e9a3d50149..0724e18efd 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -690,4 +690,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase): # test that forward pass is just a lookup, there is no ignore padding logic input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device) no_cache_pad_zero = emb1(input_ids) - self.assertTrue(torch.allclose(torch.Tensor(self.desired_weights), no_cache_pad_zero[:3, :5], atol=1e-3)) + self.assertTrue( + torch.allclose( + torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3 + ) + )