Update unwrap from accelerate (#29933)
* Use unwrap with the one in accelerate * oups * update unwrap * fix * wording * raise error instead * comment * doc * Update src/transformers/modeling_utils.py Co-authored-by: Zach Mueller <muellerzr@gmail.com> * style * put else --------- Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
@@ -123,7 +123,6 @@ if is_torch_available():
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
from transformers.trainer_pt_utils import AcceleratorConfig
|
||||
|
||||
if is_safetensors_available():
|
||||
@@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(
|
||||
getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
|
||||
)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
||||
Reference in New Issue
Block a user