Fix top k generation for k != 0

This commit is contained in:
Catalin Voss
2019-03-02 21:54:44 -08:00
committed by GitHub
parent 2152bfeae8
commit 4b4b079272

View File

@@ -16,11 +16,17 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def top_k_logits(logits, k): def top_k_logits(logits, k):
"""
Masks everything but the k top entries as -infinity (1e10).
Used to mask logits such that e^-infinity -> 0 won't contribute to the
sum of the denominator.
"""
if k == 0: if k == 0:
return logits return logits
values, _ = torch.topk(logits, k) else:
min_values = values[:, -1] values = torch.topk(logits, k)[0]
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
if start_token is None: if start_token is None: