Add support for multiple models for one config in auto classes (#11150)
* Add support for multiple models for one config in auto classes * Use get_values everywhere * Prettier doc
This commit is contained in:
@@ -13,7 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
@@ -46,6 +47,8 @@ if is_torch_available():
|
||||
BertForSequenceClassification,
|
||||
BertForTokenClassification,
|
||||
BertModel,
|
||||
FunnelBaseModel,
|
||||
FunnelModel,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
RobertaForMaskedLM,
|
||||
@@ -218,6 +221,21 @@ class AutoModelTest(unittest.TestCase):
|
||||
self.assertEqual(model.num_parameters(), 14410)
|
||||
self.assertEqual(model.num_parameters(only_trainable=True), 14410)
|
||||
|
||||
def test_from_pretrained_with_tuple_values(self):
|
||||
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
|
||||
model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||
self.assertIsInstance(model, FunnelModel)
|
||||
|
||||
config = copy.deepcopy(model.config)
|
||||
config.architectures = ["FunnelBaseModel"]
|
||||
model = AutoModel.from_config(config)
|
||||
self.assertIsInstance(model, FunnelBaseModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
model = AutoModel.from_pretrained(tmp_dir)
|
||||
self.assertIsInstance(model, FunnelBaseModel)
|
||||
|
||||
def test_parents_and_children_in_mappings(self):
|
||||
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
|
||||
# by the parents and will return the wrong configuration type when using auto models
|
||||
@@ -242,6 +260,12 @@ class AutoModelTest(unittest.TestCase):
|
||||
assert not issubclass(
|
||||
child_config, parent_config
|
||||
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
||||
assert not issubclass(
|
||||
child_model, parent_model
|
||||
), f"{child_config.__name__} is child of {parent_config.__name__}"
|
||||
|
||||
# Tuplify child_model and parent_model since some of them could be tuples.
|
||||
if not isinstance(child_model, (list, tuple)):
|
||||
child_model = (child_model,)
|
||||
if not isinstance(parent_model, (list, tuple)):
|
||||
parent_model = (parent_model,)
|
||||
|
||||
for child, parent in [(a, b) for a in child_model for b in parent_model]:
|
||||
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
|
||||
|
||||
Reference in New Issue
Block a user