[CTRL] warn if generation prompt does not start with a control code
see also https://github.com/salesforce/ctrl/pull/50
This commit is contained in:
@@ -196,7 +196,7 @@ def main():
|
||||
|
||||
logger.info(args)
|
||||
if args.model_type in ["ctrl"]:
|
||||
if args.temperature > 0.7 :
|
||||
if args.temperature > 0.7:
|
||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||
|
||||
while True:
|
||||
@@ -224,6 +224,9 @@ def main():
|
||||
# Models with memory likes to have a long prompt for short inputs.
|
||||
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
||||
context_tokens = tokenizer.encode(raw_text)
|
||||
if args.model_type == "ctrl":
|
||||
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
|
||||
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||
out = sample_sequence(
|
||||
model=model,
|
||||
context=context_tokens,
|
||||
|
||||
Reference in New Issue
Block a user