This commit is contained in:
Lysandre
2019-10-22 14:12:33 -04:00
parent 44286b94d3
commit 7d709e55ed
10 changed files with 41 additions and 39 deletions

View File

@@ -223,7 +223,7 @@ def main():
if args.model_type in ["transfo-xl", "xlnet"]:
# Models with memory likes to have a long prompt for short inputs.
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
context_tokens = tokenizer.encode(raw_text)
context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
out = sample_sequence(
model=model,
context=context_tokens,