diff --git a/README.md b/README.md index e8506d6a39..ecba50a74e 100644 --- a/README.md +++ b/README.md @@ -413,7 +413,7 @@ and from the Salesforce CTRL model: python ./examples/run_generation.py \ --model_type=ctrl \ --length=20 \ - --model_name_or_path=gpt2 \ + --model_name_or_path=ctrl \ --temperature=0 \ --repetition_penalty=1.2 \ ``` diff --git a/examples/README.md b/examples/README.md index 6b68d880eb..3a76a4a830 100644 --- a/examples/README.md +++ b/examples/README.md @@ -101,7 +101,7 @@ python run_lm_finetuning.py \ Based on the script [`run_generation.py`](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py). -Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. +Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL, XLNet, CTRL. A similar script is used for our official demo [Write With Transfomer](https://transformer.huggingface.co), where you can try out the different models available in the library. diff --git a/examples/run_generation.py b/examples/run_generation.py index ef58cfd844..ae0e27dcf0 100644 --- a/examples/run_generation.py +++ b/examples/run_generation.py @@ -196,7 +196,7 @@ def main(): logger.info(args) if args.model_type in ["ctrl"]: - if args.temperature > 0.7 : + if args.temperature > 0.7: logger.info('CTRL typically works better with lower temperatures (and lower top_k).') while True: @@ -224,6 +224,9 @@ def main(): # Models with memory likes to have a long prompt for short inputs. raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text context_tokens = tokenizer.encode(raw_text) + if args.model_type == "ctrl": + if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()): + logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") out = sample_sequence( model=model, context=context_tokens, diff --git a/transformers/tokenization_ctrl.py b/transformers/tokenization_ctrl.py index c8d67ad043..3d67fa2c5b 100644 --- a/transformers/tokenization_ctrl.py +++ b/transformers/tokenization_ctrl.py @@ -46,6 +46,64 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 'ctrl': 256, } +CONTROL_CODES = { + "Pregnancy": 168629, + "Christianity": 7675, + "Explain": 106423, + "Fitness": 63440, + "Saving": 63163, + "Ask": 27171, + "Ass": 95985, + "Joke": 163509, + "Questions": 45622, + "Thoughts": 49605, + "Retail": 52342, + "Feminism": 164338, + "Writing": 11992, + "Atheism": 192263, + "Netflix": 48616, + "Computing": 39639, + "Opinion": 43213, + "Alone": 44967, + "Funny": 58917, + "Gaming": 40358, + "Human": 4088, + "India": 1331, + "Joker": 77138, + "Diet": 36206, + "Legal": 11859, + "Norman": 4939, + "Tip": 72689, + "Weight": 52343, + "Movies": 46273, + "Running": 23425, + "Science": 2090, + "Horror": 37793, + "Confession": 60572, + "Finance": 12250, + "Politics": 16360, + "Scary": 191985, + "Support": 12654, + "Technologies": 32516, + "Teenage": 66160, + "Event": 32769, + "Learned": 67460, + "Notion": 182770, + "Wikipedia": 37583, + "Books": 6665, + "Extract": 76050, + "Confessions": 102701, + "Conspiracy": 75932, + "Links": 63674, + "Narcissus": 150425, + "Relationship": 54766, + "Relationships": 134796, + "Reviews": 41671, + "News": 4256, + "Translation": 26820, + "multilingual": 128406, +} + def get_pairs(word): """Return set of symbol pairs in a word. @@ -68,6 +126,7 @@ class CTRLTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + control_codes = CONTROL_CODES def __init__(self, vocab_file, merges_file, unk_token="", **kwargs): super(CTRLTokenizer, self).__init__(unk_token=unk_token, **kwargs)