Merge pull request #338 from CatalinVoss/patch-3
Fix top k generation for k != 0
This commit is contained in:
@@ -16,11 +16,17 @@ logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(messa
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def top_k_logits(logits, k):
|
||||
"""
|
||||
Masks everything but the k top entries as -infinity (1e10).
|
||||
Used to mask logits such that e^-infinity -> 0 won't contribute to the
|
||||
sum of the denominator.
|
||||
"""
|
||||
if k == 0:
|
||||
return logits
|
||||
values, _ = torch.topk(logits, k)
|
||||
min_values = values[:, -1]
|
||||
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
|
||||
else:
|
||||
values = torch.topk(logits, k)[0]
|
||||
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
|
||||
return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, logits)
|
||||
|
||||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
|
||||
if start_token is None:
|
||||
|
||||
Reference in New Issue
Block a user