Don't push checkpoints to hub in no_trainer scripts (#16703)
Adds checkpoint prefixes to the gitignore if `push_to_hub` is used along with `checkpointint_steps`
This commit is contained in:
@@ -39,6 +39,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
@@ -50,7 +51,6 @@ from transformers import (
|
||||
SchedulerType,
|
||||
default_data_collator,
|
||||
get_scheduler,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import get_full_repo_name
|
||||
from transformers.utils.versions import require_version
|
||||
@@ -258,6 +258,12 @@ def main():
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -542,7 +548,6 @@ def main():
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ from tqdm.auto import tqdm
|
||||
|
||||
import transformers
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
@@ -50,7 +51,6 @@ from transformers import (
|
||||
DataCollatorForLanguageModeling,
|
||||
SchedulerType,
|
||||
get_scheduler,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import get_full_repo_name
|
||||
from transformers.utils.versions import require_version
|
||||
@@ -269,6 +269,12 @@ def main():
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
Reference in New Issue
Block a user