cleanup deltas
This commit is contained in:
@@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase):
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config).to(torch_device)
|
||||
context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
|
||||
summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
|
||||
lm_model = BartForMaskedLM(config)
|
||||
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
|
||||
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
def test_generate_beam_search(self):
|
||||
input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]])
|
||||
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
@@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
max_position_embeddings=48,
|
||||
output_past=True,
|
||||
)
|
||||
lm_model = BartForMaskedLM(config).to(torch_device)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
lm_model.eval()
|
||||
|
||||
new_input_ids = lm_model.generate(
|
||||
|
||||
Reference in New Issue
Block a user