[CTRL] warn if generation prompt does not start with a control code

see also https://github.com/salesforce/ctrl/pull/50
This commit is contained in:
Julien Chaumond
2019-10-22 21:27:20 +00:00
parent e16d46843a
commit ef1b8b2ae5
4 changed files with 65 additions and 3 deletions

View File

@@ -196,7 +196,7 @@ def main():
logger.info(args)
if args.model_type in ["ctrl"]:
if args.temperature > 0.7 :
if args.temperature > 0.7:
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
while True:
@@ -224,6 +224,9 @@ def main():
# Models with memory likes to have a long prompt for short inputs.
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
context_tokens = tokenizer.encode(raw_text)
if args.model_type == "ctrl":
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
out = sample_sequence(
model=model,
context=context_tokens,