update transfo xl example
This commit is contained in:
@@ -28,7 +28,7 @@ import math
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
|
from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -79,7 +79,7 @@ def main():
|
|||||||
device=device, ext_len=args.ext_len)
|
device=device, ext_len=args.ext_len)
|
||||||
|
|
||||||
# Load a pre-trained model
|
# Load a pre-trained model
|
||||||
model = TransfoXLModel.from_pretrained(args.model_name)
|
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
||||||
@@ -139,4 +139,4 @@ def main():
|
|||||||
logger.info('=' * 100)
|
logger.info('=' * 100)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
|
|||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
if target is not None:
|
if target is not None:
|
||||||
logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
|
logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
|
||||||
else:
|
else:
|
||||||
out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
|
out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]]
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user