tests pass

This commit is contained in:
sshleifer
2020-03-05 12:33:08 -05:00
parent 7ac47bfe69
commit c36fdc88d4
3 changed files with 25 additions and 10 deletions

View File

@@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
)
lm_model = BartForMaskedLM(config)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
lm_model = BartForMaskedLM(config).to(torch_device)
context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
def test_generate_beam_search(self):
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]])
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
@@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase):
max_position_embeddings=48,
output_past=True,
)
lm_model = BartForMaskedLM(config)
lm_model = BartForMaskedLM(config).to(torch_device)
lm_model.eval()
new_input_ids = lm_model.generate(
@@ -294,6 +294,13 @@ class BartHeadTests(unittest.TestCase):
bart_toks = tokenizer.encode(ex, return_tensors="pt")
_assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_generate_fp16(self):
config, input_ids, batch_size = self._get_config_and_data(output_past=True)
attention_mask = input_ids.ne(1)
lm_model = BartForMaskedLM(config).eval().to(torch_device).half()
lm_model.generate(input_ids, attention_mask)
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""