Update unwrap from accelerate (#29933)

* Use unwrap with the one in accelerate

* oups

* update unwrap

* fix

* wording

* raise error instead

* comment

* doc

* Update src/transformers/modeling_utils.py

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* style

* put else

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
This commit is contained in:
Marc Sun
2024-04-19 18:05:34 +02:00
committed by GitHub
parent fbd8c51ffc
commit b4fd49b6c5
3 changed files with 36 additions and 18 deletions

View File

@@ -123,7 +123,6 @@ if is_torch_available():
Trainer,
TrainerState,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_pt_utils import AcceleratorConfig
if is_safetensors_available():
@@ -2468,8 +2467,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
self.assertEqual(trainer.model, trainer.accelerator.unwrap_model(wrapped_model_to_check))
self.assertGreaterEqual(
getattr(trainer.accelerator.unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0
)
# with plain model
assert_flos_extraction(trainer, trainer.model)