Cleanup TPU bits from run_glue.py
TPU runner is currently implemented in: https://github.com/pytorch-tpu/transformers/blob/tpu/examples/run_glue_tpu.py. We plan to upstream this directly into `huggingface/transformers` (either `master` or `tpu`) branch once it's been more thoroughly tested.
This commit is contained in:
committed by
Lysandre Debut
parent
454455c695
commit
e70cdf083d
@@ -158,7 +158,7 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
tr_loss += loss.item()
|
tr_loss += loss.item()
|
||||||
if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
|
if (step + 1) % args.gradient_accumulation_steps == 0:
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
||||||
else:
|
else:
|
||||||
@@ -189,11 +189,6 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
|
||||||
logger.info("Saving model checkpoint to %s", output_dir)
|
logger.info("Saving model checkpoint to %s", output_dir)
|
||||||
|
|
||||||
if args.tpu:
|
|
||||||
args.xla_model.optimizer_step(optimizer, barrier=True)
|
|
||||||
model.zero_grad()
|
|
||||||
global_step += 1
|
|
||||||
|
|
||||||
if args.max_steps > 0 and global_step > args.max_steps:
|
if args.max_steps > 0 and global_step > args.max_steps:
|
||||||
epoch_iterator.close()
|
epoch_iterator.close()
|
||||||
break
|
break
|
||||||
@@ -397,15 +392,6 @@ def main():
|
|||||||
parser.add_argument('--seed', type=int, default=42,
|
parser.add_argument('--seed', type=int, default=42,
|
||||||
help="random seed for initialization")
|
help="random seed for initialization")
|
||||||
|
|
||||||
parser.add_argument('--tpu', action='store_true',
|
|
||||||
help="Whether to run on the TPU defined in the environment variables")
|
|
||||||
parser.add_argument('--tpu_ip_address', type=str, default='',
|
|
||||||
help="TPU IP address if none are set in the environment variables")
|
|
||||||
parser.add_argument('--tpu_name', type=str, default='',
|
|
||||||
help="TPU name if none are set in the environment variables")
|
|
||||||
parser.add_argument('--xrt_tpu_config', type=str, default='',
|
|
||||||
help="XRT TPU config if none are set in the environment variables")
|
|
||||||
|
|
||||||
parser.add_argument('--fp16', action='store_true',
|
parser.add_argument('--fp16', action='store_true',
|
||||||
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
|
||||||
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
parser.add_argument('--fp16_opt_level', type=str, default='O1',
|
||||||
@@ -439,23 +425,6 @@ def main():
|
|||||||
args.n_gpu = 1
|
args.n_gpu = 1
|
||||||
args.device = device
|
args.device = device
|
||||||
|
|
||||||
if args.tpu:
|
|
||||||
if args.tpu_ip_address:
|
|
||||||
os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
|
|
||||||
if args.tpu_name:
|
|
||||||
os.environ["TPU_NAME"] = args.tpu_name
|
|
||||||
if args.xrt_tpu_config:
|
|
||||||
os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
|
|
||||||
|
|
||||||
assert "TPU_IP_ADDRESS" in os.environ
|
|
||||||
assert "TPU_NAME" in os.environ
|
|
||||||
assert "XRT_TPU_CONFIG" in os.environ
|
|
||||||
|
|
||||||
import torch_xla
|
|
||||||
import torch_xla.core.xla_model as xm
|
|
||||||
args.device = xm.xla_device()
|
|
||||||
args.xla_model = xm
|
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -509,7 +478,7 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
|
||||||
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
|
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
|
||||||
os.makedirs(args.output_dir)
|
os.makedirs(args.output_dir)
|
||||||
|
|||||||
Reference in New Issue
Block a user