Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes * Use get_values everywhere * Prettier doc
This commit is contained in:
@@ -13,7 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_tf_available
|
||||
@@ -39,6 +40,8 @@ if is_tf_available():
|
||||
TFBertForQuestionAnswering,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFFunnelBaseModel,
|
||||
TFFunnelModel,
|
||||
TFGPT2LMHeadModel,
|
||||
TFRobertaForMaskedLM,
|
||||
TFT5ForConditionalGeneration,
|
||||
@@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
def test_from_pretrained_with_tuple_values(self):
|
||||
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
|
||||
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||
self.assertIsInstance(model, TFFunnelModel)
|
||||
|
||||
config = copy.deepcopy(model.config)
|
||||
config.architectures = ["FunnelBaseModel"]
|
||||
model = TFAutoModel.from_config(config)
|
||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
model = TFAutoModel.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||
|
||||
def test_parents_and_children_in_mappings(self):
|
||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||
# by the parents and will return the wrong configuration type when using auto models
|
||||
@@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
for parent_config, parent_model in mapping[: index + 1]:
|
||||
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
|
||||
self.assertFalse(issubclass(child_config, parent_config))
|
||||
self.assertFalse(issubclass(child_model, parent_model))
|
||||
|
||||
# Tuplify child_model and parent_model since some of them could be tuples.
|
||||
if not isinstance(child_model, (list, tuple)):
|
||||
child_model = (child_model,)
|
||||
if not isinstance(parent_model, (list, tuple)):
|
||||
parent_model = (parent_model,)
|
||||
|
||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||
|
||||
Reference in New Issue
Block a user