make style (#11442)

This commit is contained in:
Patrick von Platen
2021-04-26 13:50:34 +02:00
committed by GitHub
parent 04ab2ca639
commit 32dbb2d954
105 changed files with 202 additions and 202 deletions

View File

@@ -251,7 +251,7 @@ class TransformerDecoder(nn.Module):
return output, state # , state
def init_decoder_state(self, src, memory_bank, with_cache=False):
""" Init decoder state """
"""Init decoder state"""
state = TransformerDecoderState(src)
if with_cache:
state._init_cache(memory_bank, self.num_layers)
@@ -479,11 +479,11 @@ class MultiHeadedAttention(nn.Module):
head_count = self.head_count
def shape(x):
""" projection """
"""projection"""
return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)
def unshape(x):
""" compute context """
"""compute context"""
return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)
# 1) Project key, value, and query.
@@ -571,12 +571,12 @@ class DecoderState(object):
"""
def detach(self):
""" Need to document this """
"""Need to document this"""
self.hidden = tuple([_.detach() for _ in self.hidden])
self.input_feed = self.input_feed.detach()
def beam_update(self, idx, positions, beam_size):
""" Need to document this """
"""Need to document this"""
for e in self._all:
sizes = e.size()
br = sizes[1]
@@ -592,7 +592,7 @@ class DecoderState(object):
class TransformerDecoderState(DecoderState):
""" Transformer Decoder state base class """
"""Transformer Decoder state base class"""
def __init__(self, src):
"""
@@ -638,7 +638,7 @@ class TransformerDecoderState(DecoderState):
self.cache["layer_{}".format(l)] = layer_cache
def repeat_beam_size_times(self, beam_size):
""" Repeat beam_size times along batch dimension. """
"""Repeat beam_size times along batch dimension."""
self.src = self.src.data.repeat(1, beam_size, 1)
def map_batch_fn(self, fn):

View File

@@ -25,19 +25,19 @@ class SummarizationDataProcessingTest(unittest.TestCase):
self.block_size = 10
def test_fit_to_block_sequence_too_small(self):
""" Pad the sequence with 0 if the sequence is smaller than the block size."""
"""Pad the sequence with 0 if the sequence is smaller than the block size."""
sequence = [1, 2, 3, 4]
expected_output = [1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_fit_to_block_sequence_fit_exactly(self):
""" Do nothing if the sequence is the right size. """
"""Do nothing if the sequence is the right size."""
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)
def test_fit_to_block_sequence_too_big(self):
""" Truncate the sequence if it is too long. """
"""Truncate the sequence if it is too long."""
sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
expected_output = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
self.assertEqual(truncate_or_pad(sequence, self.block_size, 0), expected_output)

View File

@@ -47,7 +47,7 @@ class CNNDMDataset(Dataset):
self.documents.append(path_to_story)
def __len__(self):
""" Returns the number of documents. """
"""Returns the number of documents."""
return len(self.documents)
def __getitem__(self, idx):