From 42e1e3c67fd0e7e5c9b58cf8e165df635da2c8e5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 31 Mar 2020 15:31:46 +0200 Subject: [PATCH] Update usage doc regarding generate fn (#3504) --- docs/source/usage.rst | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 8fb7a44727..6e53af1849 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -420,7 +420,7 @@ to generate the tokens following the initial sequence in PyTorch, and creating a sequence = f"Hugging Face is based in DUMBO, New York City, and is" input = tokenizer.encode(sequence, return_tensors="pt") - generated = model.generate(input, max_length=50) + generated = model.generate(input, max_length=50, do_sample=True) resulting_string = tokenizer.decode(generated.tolist()[0]) print(resulting_string) @@ -432,14 +432,10 @@ to generate the tokens following the initial sequence in PyTorch, and creating a model = TFAutoModelWithLMHead.from_pretrained("gpt2") sequence = f"Hugging Face is based in DUMBO, New York City, and is" - generated = tokenizer.encode(sequence) + input = tokenizer.encode(sequence, return_tensors="tf") + generated = model.generate(input, max_length=50, do_sample=True) - for i in range(50): - predictions = model(tf.constant([generated]))[0] - token = tf.argmax(predictions[0], axis=1)[-1].numpy() - generated += [token] - - resulting_string = tokenizer.decode(generated) + resulting_string = tokenizer.decode(generated.tolist()[0]) print(resulting_string)