Update usage doc regarding generate fn (#3504)
This commit is contained in:
committed by
GitHub
parent
57b0fab692
commit
42e1e3c67f
@@ -420,7 +420,7 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
|
|||||||
sequence = f"Hugging Face is based in DUMBO, New York City, and is"
|
sequence = f"Hugging Face is based in DUMBO, New York City, and is"
|
||||||
|
|
||||||
input = tokenizer.encode(sequence, return_tensors="pt")
|
input = tokenizer.encode(sequence, return_tensors="pt")
|
||||||
generated = model.generate(input, max_length=50)
|
generated = model.generate(input, max_length=50, do_sample=True)
|
||||||
|
|
||||||
resulting_string = tokenizer.decode(generated.tolist()[0])
|
resulting_string = tokenizer.decode(generated.tolist()[0])
|
||||||
print(resulting_string)
|
print(resulting_string)
|
||||||
@@ -432,14 +432,10 @@ to generate the tokens following the initial sequence in PyTorch, and creating a
|
|||||||
model = TFAutoModelWithLMHead.from_pretrained("gpt2")
|
model = TFAutoModelWithLMHead.from_pretrained("gpt2")
|
||||||
|
|
||||||
sequence = f"Hugging Face is based in DUMBO, New York City, and is"
|
sequence = f"Hugging Face is based in DUMBO, New York City, and is"
|
||||||
generated = tokenizer.encode(sequence)
|
input = tokenizer.encode(sequence, return_tensors="tf")
|
||||||
|
generated = model.generate(input, max_length=50, do_sample=True)
|
||||||
|
|
||||||
for i in range(50):
|
resulting_string = tokenizer.decode(generated.tolist()[0])
|
||||||
predictions = model(tf.constant([generated]))[0]
|
|
||||||
token = tf.argmax(predictions[0], axis=1)[-1].numpy()
|
|
||||||
generated += [token]
|
|
||||||
|
|
||||||
resulting_string = tokenizer.decode(generated)
|
|
||||||
print(resulting_string)
|
print(resulting_string)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user