update transfo xl example

This commit is contained in:
thomwolf
2019-02-09 16:59:17 +01:00
parent 1320e4ec0c
commit 6cd769957e
2 changed files with 4 additions and 4 deletions

View File

@@ -28,7 +28,7 @@ import math
import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -79,7 +79,7 @@ def main():
device=device, ext_len=args.ext_len)
# Load a pre-trained model
model = TransfoXLModel.from_pretrained(args.model_name)
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
model = model.to(device)
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
@@ -139,4 +139,4 @@ def main():
logger.info('=' * 100)
if __name__ == '__main__':
main()
main()