cleanup tf unittests: part 2 (#6260)

* cleanup torch unittests: part 2

* remove trailing comma added by isort, and which breaks flake

* one more comma

* revert odd balls

* part 3: odd cases

* more ["key"] -> .key refactoring

* .numpy() is not needed

* more unncessary .numpy() removed

* more simplification
This commit is contained in:
Stas Bekman
2020-08-13 01:29:06 -07:00
committed by GitHub
parent bc820476a5
commit e983da0e7d
21 changed files with 159 additions and 239 deletions

View File

@@ -120,7 +120,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(result["logits"].shape, expected_shape)
self.assertEqual(result.logits.shape, expected_shape)
@require_torch