Update usage doc regarding generate fn (#3504)

This commit is contained in:
Patrick von Platen
2020-03-31 15:31:46 +02:00
committed by GitHub
parent 57b0fab692
commit 42e1e3c67f

View File

@@ -420,7 +420,7 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
sequence = f"Hugging Face is based in DUMBO, New York City, and is"
input = tokenizer.encode(sequence, return_tensors="pt")
generated = model.generate(input, max_length=50)
generated = model.generate(input, max_length=50, do_sample=True)
resulting_string = tokenizer.decode(generated.tolist()[0])
print(resulting_string)
@@ -432,14 +432,10 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
model = TFAutoModelWithLMHead.from_pretrained("gpt2")
sequence = f"Hugging Face is based in DUMBO, New York City, and is"
generated = tokenizer.encode(sequence)
input = tokenizer.encode(sequence, return_tensors="tf")
generated = model.generate(input, max_length=50, do_sample=True)
for i in range(50):
predictions = model(tf.constant([generated]))[0]
token = tf.argmax(predictions[0], axis=1)[-1].numpy()
generated += [token]
resulting_string = tokenizer.decode(generated)
resulting_string = tokenizer.decode(generated.tolist()[0])
print(resulting_string)