[trainer] self.model_wrapped + _model_unwrap (#9390)
* model wrapped + model_unwrap * cleanup * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * style * deprecation warning * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -53,6 +53,7 @@ if is_torch_available():
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
from transformers.trainer import _model_unwrap
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
@@ -850,8 +851,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, trainer._actual_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(trainer._actual_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
||||
Reference in New Issue
Block a user