cleanup tf unittests: part 2 (#6260)
* cleanup torch unittests: part 2 * remove trailing comma added by isort, and which breaks flake * one more comma * revert odd balls * part 3: odd cases * more ["key"] -> .key refactoring * .numpy() is not needed * more unncessary .numpy() removed * more simplification
This commit is contained in:
@@ -120,7 +120,7 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
|
||||
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device)
|
||||
result = lm_model(input_ids=context, decoder_input_ids=summary, labels=summary)
|
||||
expected_shape = (*summary.shape, config.vocab_size)
|
||||
self.assertEqual(result["logits"].shape, expected_shape)
|
||||
self.assertEqual(result.logits.shape, expected_shape)
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user