Merge pull request #338 from CatalinVoss/patch-3
Fix top k generation for k != 0
This commit is contained in:
@@ -16,11 +16,17 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def top_k_logits(logits, k):
    """
    Keep the k largest logits along the last dimension and mask the rest.

    Masked entries are set to -1e10, a large negative stand-in for -infinity,
    so that after a softmax e^(-1e10) -> 0 and they do not contribute to the
    sum in the denominator.

    Args:
        logits: floating-point tensor whose last dimension holds the scores
            to filter (e.g. shape (batch, vocab)); any number of leading
            dimensions is accepted.
        k: number of entries to keep per row. 0 disables filtering. Values
            larger than the size of the last dimension are clamped to it.

    Returns:
        ``logits`` itself when ``k == 0``; otherwise a new tensor with the
        same shape and dtype. Entries tied with the k-th largest value are
        all kept, so more than k entries may survive when there are ties.
    """
    if k == 0:
        return logits

    # torch.topk raises when k exceeds the dimension size; asking for more
    # entries than exist simply means "keep everything".
    k = min(k, logits.size(-1))

    values, _ = torch.topk(logits, k)
    # topk returns values sorted in descending order, so the last column is
    # the k-th largest. Slicing with -1: keeps a trailing dimension of size 1,
    # which makes the threshold broadcast per row instead of per column.
    kth_largest = values[..., -1:]
    return torch.where(logits < kth_largest, torch.ones_like(logits) * -1e10, logits)
|
||||||
|
|
||||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
|
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
|
||||||
if start_token is None:
|
if start_token is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user