make style (#11442)
This commit is contained in:
committed by
GitHub
parent
04ab2ca639
commit
32dbb2d954
@@ -49,14 +49,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def entropy(p):
|
||||
""" Compute the entropy of a probability distribution """
|
||||
"""Compute the entropy of a probability distribution"""
|
||||
plogp = p * torch.log(p)
|
||||
plogp[p == 0] = 0
|
||||
return -plogp.sum(dim=-1)
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
""" Print a 2D tensor """
|
||||
"""Print a 2D tensor"""
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
|
||||
@@ -36,7 +36,7 @@ def save_model(model, dirpath):
|
||||
|
||||
|
||||
def entropy(p, unlogit=False):
|
||||
""" Compute the entropy of a probability distribution """
|
||||
"""Compute the entropy of a probability distribution"""
|
||||
exponent = 2
|
||||
if unlogit:
|
||||
p = torch.pow(p, exponent)
|
||||
@@ -46,7 +46,7 @@ def entropy(p, unlogit=False):
|
||||
|
||||
|
||||
def print_2d_tensor(tensor):
|
||||
""" Print a 2D tensor """
|
||||
"""Print a 2D tensor"""
|
||||
logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(len(tensor))))
|
||||
for row in range(len(tensor)):
|
||||
if tensor.dtype != torch.long:
|
||||
|
||||
Reference in New Issue
Block a user