Fix bug in _sorted_checkpoints (#7880)
I'm using transformers 3.3.1 and run a training script with `--save_total_limit 3`. I hit the exception below, and after debugging the code found that it wrongly tries to index into the `best_model_checkpoint`'s *str* rather than the `sorted_checkpoints` array. When running without the fix I got this exception:
```
Traceback (most recent call last):
File "/<HOME>/.conda/envs/transformers/lib/python3.7/site-packages/transformers/trainer.py", line 921, in _save_training
self._rotate_checkpoints(use_mtime=True)
File "/<HOME>/.conda/envs/transformers/lib/python3.7/site-packages/transformers/trainer.py", line 1283, in _rotate_checkpoints
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
File "/<HOME>/.conda/envs/transformers/lib/python3.7/site-packages/transformers/trainer.py", line 1274, in _sorted_checkpoints
checkpoints_sorted[best_model_index],
TypeError: 'str' object does not support item assignment
```
This commit is contained in:
@@ -1204,7 +1204,7 @@ class Trainer:
|
|||||||
# Make sure we don't delete the best model.
|
# Make sure we don't delete the best model.
|
||||||
if self.state.best_model_checkpoint is not None:
|
if self.state.best_model_checkpoint is not None:
|
||||||
best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
|
best_model_index = checkpoints_sorted.index(self.state.best_model_checkpoint)
|
||||||
checkpoints_sorted[best_model_index], checkpoints_sorted[best_model_index][-1] = (
|
checkpoints_sorted[best_model_index], checkpoints_sorted[-1] = (
|
||||||
checkpoints_sorted[-1],
|
checkpoints_sorted[-1],
|
||||||
checkpoints_sorted[best_model_index],
|
checkpoints_sorted[best_model_index],
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user