From 87c8fca9bc39435f518e0b60e44aafc374333886 Mon Sep 17 00:00:00 2001 From: patrickvonplaten Date: Wed, 25 Dec 2019 02:27:25 +0100 Subject: [PATCH] add example for ctrl text generation in docs --- src/transformers/modeling_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d4421a05a6..5a36b436be 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -624,6 +624,14 @@ class PreTrainedModel(nn.Module): input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams) print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + + tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer + model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache. + input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl + input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context + outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using greedy search + print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) + """ # We cannot generate if the model does not have a LM head