Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
This commit is contained in:
Sylvain Gugger
2021-04-08 18:41:36 -04:00
committed by GitHub
parent 97ccf67bb3
commit ba8b1f4754
26 changed files with 188 additions and 72 deletions

View File

@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile
import unittest
from transformers import is_tf_available
@@ -39,6 +40,8 @@ if is_tf_available():
TFBertForQuestionAnswering,
TFBertForSequenceClassification,
TFBertModel,
TFFunnelBaseModel,
TFFunnelModel,
TFGPT2LMHeadModel,
TFRobertaForMaskedLM,
TFT5ForConditionalGeneration,
@@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase):
self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_from_pretrained_with_tuple_values(self):
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
self.assertIsInstance(model, TFFunnelModel)
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
self.assertIsInstance(model, TFFunnelBaseModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = TFAutoModel.from_pretrained(tmp_dir)
self.assertIsInstance(model, TFFunnelBaseModel)
def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
# by the parents and will return the wrong configuration type when using auto models
@@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase):
for parent_config, parent_model in mapping[: index + 1]:
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
self.assertFalse(issubclass(child_config, parent_config))
self.assertFalse(issubclass(child_model, parent_model))
# Tuplify child_model and parent_model since some of them could be tuples.
if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"