Black 20 release
This commit is contained in:
@@ -181,7 +181,10 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
|
||||
model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
# Train!
|
||||
@@ -304,16 +307,22 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
loss_start = F.kl_div(
|
||||
input=F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=F.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature ** 2)
|
||||
loss_end = F.kl_div(
|
||||
input=F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=F.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
) * (args.temperature ** 2)
|
||||
loss_start = (
|
||||
F.kl_div(
|
||||
input=F.log_softmax(start_logits_stu / args.temperature, dim=-1),
|
||||
target=F.softmax(start_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_end = (
|
||||
F.kl_div(
|
||||
input=F.log_softmax(end_logits_stu / args.temperature, dim=-1),
|
||||
target=F.softmax(end_logits_tea / args.temperature, dim=-1),
|
||||
reduction="batchmean",
|
||||
)
|
||||
* (args.temperature ** 2)
|
||||
)
|
||||
loss_logits = (loss_start + loss_end) / 2.0
|
||||
|
||||
loss = args.alpha_distil * loss_logits + args.alpha_ce * loss
|
||||
@@ -859,7 +868,10 @@ def main():
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
|
||||
"--num_train_epochs",
|
||||
default=3.0,
|
||||
type=float,
|
||||
help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
|
||||
Reference in New Issue
Block a user