diff --git a/examples/run_transfo_xl.py b/examples/run_transfo_xl.py index b8000a2080..bf0d1a3d38 100644 --- a/examples/run_transfo_xl.py +++ b/examples/run_transfo_xl.py @@ -28,7 +28,7 @@ import math import torch -from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus +from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', @@ -79,7 +79,7 @@ def main(): device=device, ext_len=args.ext_len) # Load a pre-trained model - model = TransfoXLModel.from_pretrained(args.model_name) + model = TransfoXLLMHeadModel.from_pretrained(args.model_name) model = model.to(device) logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( @@ -139,4 +139,4 @@ def main(): logger.info('=' * 100) if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py b/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py index 37c38d3776..0a65371c61 100644 --- a/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py +++ b/pytorch_pretrained_bert/modeling_transfo_xl_utilities.py @@ -169,7 +169,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module): if i == 0: if target is not None: - logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) + logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) else: out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] else: