add example for ctrl text generation in docs
This commit is contained in:
@@ -624,6 +624,14 @@ class PreTrainedModel(nn.Module):
|
|||||||
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||||
outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams)
|
outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, bos_token_id=tokenizer.bos_token_id, eos_token_ids=tokenizer.eos_token_id, num_beams=3) # generate sequences using beam search decoding (3 beams)
|
||||||
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
|
||||||
|
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
|
||||||
|
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
|
||||||
|
input_ids = torch.tensor(tokenizer.encode(input_context)).unsqueeze(0) # encode input context
|
||||||
|
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences using greedy search
|
||||||
|
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We cannot generate if the model does not have a LM head
|
# We cannot generate if the model does not have a LM head
|
||||||
|
|||||||
Reference in New Issue
Block a user