make style (#11442)
This commit is contained in:
committed by
GitHub
parent
04ab2ca639
commit
32dbb2d954
@@ -71,7 +71,7 @@ def set_seed(args):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ class TransformerDecoder(nn.Module):
|
||||
return output, state # , state
|
||||
|
||||
def init_decoder_state(self, src, memory_bank, with_cache=False):
|
||||
""" Init decoder state """
|
||||
"""Init decoder state"""
|
||||
state = TransformerDecoderState(src)
|
||||
if with_cache:
|
||||
state._init_cache(memory_bank, self.num_layers)
|
||||
@@ -479,11 +479,11 @@ class MultiHeadedAttention(nn.Module):
|
||||
head_count = self.head_count
|
||||
|
||||
def shape(x):
|
||||
""" projection """
|
||||
"""projection"""
|
||||
return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
|
||||
|
||||
def unshape(x):
|
||||
""" compute context """
|
||||
"""compute context"""
|
||||
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
|
||||
|
||||
# 1) Project key, value, and query.
|
||||
@@ -571,12 +571,12 @@ class DecoderState(object):
|
||||
"""
|
||||
|
||||
def detach(self):
|
||||
""" Need to document this """
|
||||
"""Need to document this"""
|
||||
self.hidden = tuple([_.detach() for _ in self.hidden])
|
||||
self.input_feed = self.input_feed.detach()
|
||||
|
||||
def beam_update(self, idx, positions, beam_size):
|
||||
""" Need to document this """
|
||||
"""Need to document this"""
|
||||
for e in self._all:
|
||||
sizes = e.size()
|
||||
br = sizes[1]
|
||||
@@ -592,7 +592,7 @@ class DecoderState(object):
|
||||
|
||||
|
||||
class TransformerDecoderState(DecoderState):
|
||||
""" Transformer Decoder state base class """
|
||||
"""Transformer Decoder state base class"""
|
||||
|
||||
def __init__(self, src):
|
||||
"""
|
||||
@@ -638,7 +638,7 @@ class TransformerDecoderState(DecoderState):
|
||||
self.cache["layer_{}".format(l)] = layer_cache
|
||||
|
||||
def repeat_beam_size_times(self, beam_size):
|
||||
""" Repeat beam_size times along batch dimension. """
|
||||
"""Repeat beam_size times along batch dimension."""
|
||||
self.src = self.src.data.repeat(1, beam_size, 1)
|
||||
|
||||
def map_batch_fn(self, fn):
|
||||
|
||||
@@ -25,19 +25,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
|
||||
self.block_size = 10
|
||||
|
||||
def test_fit_to_block_sequence_too_small(self):
|
||||
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
"""Pad the sequence with 0 if the sequence is smaller than the block size."""
|
||||
sequence = [1, 2, 3, 4]
|
||||
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_fit_exactly(self):
|
||||
""" Do nothing if the sequence is the right size. """
|
||||
"""Do nothing if the sequence is the right size."""
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
def test_fit_to_block_sequence_too_big(self):
|
||||
""" Truncate the sequence if it is too long. """
|
||||
"""Truncate the sequence if it is too long."""
|
||||
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
|
||||
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
|
||||
|
||||
@@ -47,7 +47,7 @@ class CNNDMDataset(Dataset):
|
||||
self.documents.append(path_to_story)
|
||||
|
||||
def __len__(self):
|
||||
""" Returns the number of documents. """
|
||||
"""Returns the number of documents."""
|
||||
return len(self.documents)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
@@ -49,14 +49,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def entropy(p):
|
||||
""" Compute the entropy of a probability distribution """
|
||||
"""Compute the entropy of a probability distribution"""
|
||||
plogp = p * torch.log(p)
|
||||
plogp[p == 0] = 0
|
||||
return -plogp.sum(dim=-1)
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
""" Print a 2D tensor """
|
||||
"""Print a 2D tensor"""
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
|
||||
@@ -36,7 +36,7 @@ def save_model(model, dirpath):
|
||||
|
||||
|
||||
def entropy(p, unlogit=False):
|
||||
""" Compute the entropy of a probability distribution """
|
||||
"""Compute the entropy of a probability distribution"""
|
||||
exponent = 2
|
||||
if unlogit:
|
||||
p = torch.pow(p, exponent)
|
||||
@@ -46,7 +46,7 @@ def entropy(p, unlogit=False):
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
""" Print a 2D tensor """
|
||||
"""Print a 2D tensor"""
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
|
||||
@@ -70,7 +70,7 @@ def get_wanted_result(result):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, train_highway=False):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ def to_list(tensor):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ def set_seed(args):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, criterion):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter()
|
||||
|
||||
|
||||
@@ -393,7 +393,7 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
|
||||
base_model_prefix = "bert"
|
||||
|
||||
def _init_weights(self, module):
|
||||
""" Initialize the weights """
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
|
||||
@@ -105,7 +105,7 @@ def regularization(model: nn.Module, mode: str):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter(log_dir=args.output_dir)
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ def to_list(tensor):
|
||||
|
||||
|
||||
def train(args, train_dataset, model, tokenizer, teacher=None):
|
||||
""" Train the model """
|
||||
"""Train the model"""
|
||||
if args.local_rank in [-1, 0]:
|
||||
tb_writer = SummaryWriter(log_dir=args.output_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user