make style (#11442)
This commit is contained in:
committed by
GitHub
parent
04ab2ca639
commit
32dbb2d954
@@ -393,7 +393,7 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "bert"
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
|
||||
@@ -105,7 +105,7 @@ def regularization(model: nn.Module, mode: str):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter(log_dir=args.output_dir)
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ def to_list(tensor):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter(log_dir=args.output_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user