[trainer] self.model_wrapped + _model_unwrap (#9390)

* model wrapped + model_unwrap

* cleanup

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* deprecation warning

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Stas Bekman
2021-01-06 03:50:11 -08:00
committed by GitHub
parent 453a70d4cb
commit 9f675b05d4
2 changed files with 63 additions and 51 deletions

View File

@@ -53,6 +53,7 @@ if is_torch_available():
Trainer,
TrainerState,
)
from transformers.trainer import _model_unwrap
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
# with plain model
assert_flos_extraction(trainer, trainer.model)