Reformat source code with black.
This is the result of:
$ black --line-length 119 examples templates transformers utils hubconf.py setup.py
There's a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.
This is also Thomas' preference, because it allows for explicit variable
names, to make the code easier to understand.
This commit is contained in:
@@ -23,9 +23,12 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
import logging
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,12 +38,12 @@ def git_log(folder_path: str):
|
||||
"""
|
||||
repo = git.Repo(search_parent_directories=True)
|
||||
repo_infos = {
|
||||
'repo_id': str(repo),
|
||||
'repo_sha': str(repo.head.object.hexsha),
|
||||
'repo_branch': str(repo.active_branch)
|
||||
"repo_id": str(repo),
|
||||
"repo_sha": str(repo.head.object.hexsha),
|
||||
"repo_branch": str(repo.active_branch),
|
||||
}
|
||||
|
||||
with open(os.path.join(folder_path, 'git_log.json'), 'w') as f:
|
||||
with open(os.path.join(folder_path, "git_log.json"), "w") as f:
|
||||
json.dump(repo_infos, f, indent=4)
|
||||
|
||||
|
||||
@@ -57,21 +60,21 @@ def init_gpu_params(params):
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
logger.info('Initializing GPUs')
|
||||
logger.info("Initializing GPUs")
|
||||
if params.n_gpu > 1:
|
||||
assert params.local_rank != -1
|
||||
|
||||
params.world_size = int(os.environ['WORLD_SIZE'])
|
||||
params.n_gpu_per_node = int(os.environ['N_GPU_NODE'])
|
||||
params.global_rank = int(os.environ['RANK'])
|
||||
params.world_size = int(os.environ["WORLD_SIZE"])
|
||||
params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
|
||||
params.global_rank = int(os.environ["RANK"])
|
||||
|
||||
# number of nodes / node ID
|
||||
params.n_nodes = params.world_size // params.n_gpu_per_node
|
||||
params.node_id = params.global_rank // params.n_gpu_per_node
|
||||
params.multi_gpu = True
|
||||
|
||||
assert params.n_nodes == int(os.environ['N_NODES'])
|
||||
assert params.node_id == int(os.environ['NODE_RANK'])
|
||||
assert params.n_nodes == int(os.environ["N_NODES"])
|
||||
assert params.node_id == int(os.environ["NODE_RANK"])
|
||||
|
||||
# local job (single GPU)
|
||||
else:
|
||||
@@ -114,8 +117,7 @@ def init_gpu_params(params):
|
||||
if params.multi_gpu:
|
||||
logger.info("Initializing PyTorch distributed")
|
||||
torch.distributed.init_process_group(
|
||||
init_method='env://',
|
||||
backend='nccl',
|
||||
init_method="env://", backend="nccl",
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user