add example for ctrl text generation in docs

This commit is contained in:
patrickvonplaten
2019-12-25 02:27:25 +01:00
parent 88def24c45
commit 87c8fca9bc

View File

@@ -624,6 +624,14 @@ class PreTrainedModel(nn.Module):
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams) outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams)
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True))) print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using using greedy search
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
""" """
# We cannot generate if the model does not have a LM head # We cannot generate if the model does not have a LM head