Safetensors serialization by default (#27064)
* Safetensors serialization by default * First pass on the tests * Second pass on the tests * Third pass on the tests * Fix TF weight loading from TF-format safetensors * Specific encoder-decoder fixes for weight crossloading * Add VisionEncoderDecoder fixes for TF too * Change filename test for pt-to-tf * One missing fix for TFVisionEncoderDecoder * Fix the other crossload test * Support for flax + updated tests * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Sanchit's comments * Sanchit's comments 2 * Nico's comments * Fix tests * cleanup * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Matt <rocketknight1@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -211,6 +211,8 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
config = copy.deepcopy(model.config)
|
||||
config.architectures = ["FunnelBaseModel"]
|
||||
model = TFAutoModel.from_config(config)
|
||||
model.build()
|
||||
|
||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -245,7 +247,10 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
# Now that the config is registered, it can be used as any other config with the auto-API
|
||||
tiny_config = BertModelTester(self).get_config()
|
||||
config = NewModelConfig(**tiny_config.to_dict())
|
||||
|
||||
model = auto_class.from_config(config)
|
||||
model.build()
|
||||
|
||||
self.assertIsInstance(model, TFNewModel)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
|
||||
Reference in New Issue
Block a user