Rework TPU checkpointing in Trainer (#10504)

* Rework TPU checkpointing in Trainer

* Wrap the barrier in a dist test

* Address review comments

* Remove line
This commit is contained in:
Sylvain Gugger
2021-03-04 11:46:11 -05:00
committed by GitHub
parent 805c5200dc
commit 6290169eb3
4 changed files with 74 additions and 58 deletions

View File

@@ -57,7 +57,7 @@ if is_torch_available():
Trainer,
TrainerState,
)
from transformers.trainer import _model_unwrap
from transformers.modeling_utils import unwrap_model
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -882,8 +882,8 @@ class TrainerIntegrationTest(unittest.TestCase):
trainer = get_regression_trainer(learning_rate=0.1)
def assert_flos_extraction(trainer, wrapped_model_to_check):
self.assertEqual(trainer.model, _model_unwrap(wrapped_model_to_check))
self.assertGreaterEqual(getattr(_model_unwrap(wrapped_model_to_check).config, "total_flos", 0), 0)
self.assertEqual(trainer.model, unwrap_model(wrapped_model_to_check))
self.assertGreaterEqual(getattr(unwrap_model(wrapped_model_to_check).config, "total_flos", 0), 0)
# with plain model
assert_flos_extraction(trainer, trainer.model)