Fix BART tests on GPU (#4298)
This commit is contained in:
@@ -886,7 +886,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||||||
if new_num_tokens <= old_num_tokens:
|
if new_num_tokens <= old_num_tokens:
|
||||||
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
||||||
else:
|
else:
|
||||||
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens))
|
extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
||||||
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
||||||
self.register_buffer("final_logits_bias", new_bias)
|
self.register_buffer("final_logits_bias", new_bias)
|
||||||
|
|
||||||
|
|||||||
@@ -690,4 +690,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
|
|||||||
# test that forward pass is just a lookup, there is no ignore padding logic
|
# test that forward pass is just a lookup, there is no ignore padding logic
|
||||||
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
|
input_ids = torch.tensor([[4, 10, pad, pad, pad]], dtype=torch.long, device=torch_device)
|
||||||
no_cache_pad_zero = emb1(input_ids)
|
no_cache_pad_zero = emb1(input_ids)
|
||||||
self.assertTrue(torch.allclose(torch.Tensor(self.desired_weights), no_cache_pad_zero[:3, :5], atol=1e-3))
|
self.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
torch.tensor(self.desired_weights, device=torch_device), no_cache_pad_zero[:3, :5], atol=1e-3
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user