Rework TPU checkpointing in Trainer (#10504)
* Rework TPU checkpointing in Trainer * Wraps the barrier in a dist test * Address review comments * Remove line
This commit is contained in:
@@ -57,7 +57,7 @@ if is_torch_available():
|
||||
Trainer,
|
||||
TrainerState,
|
||||
)
|
||||
from transformers.trainer import _model_unwrap
|
||||
from transformers.modeling_utils import unwrap_model
|
||||
|
||||
|
||||
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
|
||||
@@ -882,8 +882,8 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
def assert_flos_extraction(trainer, wrapped_model_to_check):
|
||||
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
|
||||
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
|
||||
|
||||
# with plain model
|
||||
assert_flos_extraction(trainer, trainer.model)
|
||||
|
||||
Reference in New Issue
Block a user