enable training mask2former and maskformer for transformers trainer (#28277)

* fix get_num_masks output as [int] to int

* fix loss size from torch.Size([1]) to torch.Size([])
This commit is contained in:
Sangbum Daniel Choi
2024-01-04 17:53:25 +09:00
committed by GitHub
parent 6b8ec2588e
commit 4a66c0d952
4 changed files with 4 additions and 4 deletions

View File

@@ -190,7 +190,7 @@ class Mask2FormerModelTester:
comm_check_on_output(result)
self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
self.parent.assertEqual(result.loss.shape, torch.Size([]))
@require_torch