num_parameters helper

This commit is contained in:
Julien Chaumond
2020-01-10 17:40:02 +00:00
parent 331065e62d
commit 84c0aa1868
4 changed files with 35 additions and 2 deletions

View File

@@ -100,3 +100,5 @@ class AutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)

View File

@@ -99,3 +99,5 @@ class TFAutoModelTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
model = TFAutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, TFBertForMaskedLM)
self.assertEqual(model.num_parameters(), 14830)
self.assertEqual(model.num_parameters(only_trainable=True), 14830)