[CTRL] warn if generation prompt does not start with a control code
see also https://github.com/salesforce/ctrl/pull/50
This commit is contained in:
@@ -101,7 +101,7 @@ python run_lm_finetuning.py \
|
||||
|
||||
Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py).
|
||||
|
||||
Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet.
|
||||
Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL.
|
||||
A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
|
||||
can try out the different models available in the library.
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ def main():
|
||||
|
||||
logger.info(args)
|
||||
if args.model_type in ["ctrl"]:
|
||||
if args.temperature > 0.7 :
|
||||
if args.temperature > 0.7:
|
||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||
|
||||
while True:
|
||||
@@ -224,6 +224,9 @@ def main():
|
||||
# Models with memory likes to have a long prompt for short inputs.
|
||||
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
||||
context_tokens = tokenizer.encode(raw_text)
|
||||
if args.model_type == "ctrl":
|
||||
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
|
||||
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||
out = sample_sequence(
|
||||
model=model,
|
||||
context=context_tokens,
|
||||
|
||||
Reference in New Issue
Block a user