[CTRL] warn if generation prompt does not start with a control code
see also https://github.com/salesforce/ctrl/pull/50
This commit is contained in:
@@ -413,7 +413,7 @@ and from the Salesforce CTRL model:
|
|||||||
python ./examples/run_generation.py \
|
python ./examples/run_generation.py \
|
||||||
--model_type=ctrl \
|
--model_type=ctrl \
|
||||||
--length=20 \
|
--length=20 \
|
||||||
--model_name_or_path=gpt2 \
|
--model_name_or_path=ctrl \
|
||||||
--temperature=0 \
|
--temperature=0 \
|
||||||
--repetition_penalty=1.2 \
|
--repetition_penalty=1.2 \
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ python run_lm_finetuning.py \
|
|||||||
|
|
||||||
Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py).
|
Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py).
|
||||||
|
|
||||||
Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet.
|
Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL.
|
||||||
A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
|
A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you
|
||||||
can try out the different models available in the library.
|
can try out the different models available in the library.
|
||||||
|
|
||||||
|
|||||||
@@ -196,7 +196,7 @@ def main():
|
|||||||
|
|
||||||
logger.info(args)
|
logger.info(args)
|
||||||
if args.model_type in ["ctrl"]:
|
if args.model_type in ["ctrl"]:
|
||||||
if args.temperature > 0.7 :
|
if args.temperature > 0.7:
|
||||||
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -224,6 +224,9 @@ def main():
|
|||||||
# Models with memory likes to have a long prompt for short inputs.
|
# Models with memory likes to have a long prompt for short inputs.
|
||||||
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
|
||||||
context_tokens = tokenizer.encode(raw_text)
|
context_tokens = tokenizer.encode(raw_text)
|
||||||
|
if args.model_type == "ctrl":
|
||||||
|
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
|
||||||
|
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
|
||||||
out = sample_sequence(
|
out = sample_sequence(
|
||||||
model=model,
|
model=model,
|
||||||
context=context_tokens,
|
context=context_tokens,
|
||||||
|
|||||||
@@ -46,6 +46,64 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
|||||||
'ctrl': 256,
|
'ctrl': 256,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CONTROL_CODES = {
|
||||||
|
"Pregnancy": 168629,
|
||||||
|
"Christianity": 7675,
|
||||||
|
"Explain": 106423,
|
||||||
|
"Fitness": 63440,
|
||||||
|
"Saving": 63163,
|
||||||
|
"Ask": 27171,
|
||||||
|
"Ass": 95985,
|
||||||
|
"Joke": 163509,
|
||||||
|
"Questions": 45622,
|
||||||
|
"Thoughts": 49605,
|
||||||
|
"Retail": 52342,
|
||||||
|
"Feminism": 164338,
|
||||||
|
"Writing": 11992,
|
||||||
|
"Atheism": 192263,
|
||||||
|
"Netflix": 48616,
|
||||||
|
"Computing": 39639,
|
||||||
|
"Opinion": 43213,
|
||||||
|
"Alone": 44967,
|
||||||
|
"Funny": 58917,
|
||||||
|
"Gaming": 40358,
|
||||||
|
"Human": 4088,
|
||||||
|
"India": 1331,
|
||||||
|
"Joker": 77138,
|
||||||
|
"Diet": 36206,
|
||||||
|
"Legal": 11859,
|
||||||
|
"Norman": 4939,
|
||||||
|
"Tip": 72689,
|
||||||
|
"Weight": 52343,
|
||||||
|
"Movies": 46273,
|
||||||
|
"Running": 23425,
|
||||||
|
"Science": 2090,
|
||||||
|
"Horror": 37793,
|
||||||
|
"Confession": 60572,
|
||||||
|
"Finance": 12250,
|
||||||
|
"Politics": 16360,
|
||||||
|
"Scary": 191985,
|
||||||
|
"Support": 12654,
|
||||||
|
"Technologies": 32516,
|
||||||
|
"Teenage": 66160,
|
||||||
|
"Event": 32769,
|
||||||
|
"Learned": 67460,
|
||||||
|
"Notion": 182770,
|
||||||
|
"Wikipedia": 37583,
|
||||||
|
"Books": 6665,
|
||||||
|
"Extract": 76050,
|
||||||
|
"Confessions": 102701,
|
||||||
|
"Conspiracy": 75932,
|
||||||
|
"Links": 63674,
|
||||||
|
"Narcissus": 150425,
|
||||||
|
"Relationship": 54766,
|
||||||
|
"Relationships": 134796,
|
||||||
|
"Reviews": 41671,
|
||||||
|
"News": 4256,
|
||||||
|
"Translation": 26820,
|
||||||
|
"multilingual": 128406,
|
||||||
|
}
|
||||||
|
|
||||||
def get_pairs(word):
|
def get_pairs(word):
|
||||||
"""Return set of symbol pairs in a word.
|
"""Return set of symbol pairs in a word.
|
||||||
|
|
||||||
@@ -68,6 +126,7 @@ class CTRLTokenizer(PreTrainedTokenizer):
|
|||||||
vocab_files_names = VOCAB_FILES_NAMES
|
vocab_files_names = VOCAB_FILES_NAMES
|
||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||||
|
control_codes = CONTROL_CODES
|
||||||
|
|
||||||
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
|
def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
|
||||||
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
|
super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user