From c7f01beece5b25f05c910b130da654283805543d Mon Sep 17 00:00:00 2001 From: tom white Date: Mon, 9 Oct 2023 23:18:02 +1300 Subject: [PATCH] fix typos in idefics.md (#26648) * fix typos in idefics.md Two typos found in reviewing this documentation. 1) max_new_tokens=4, is not sufficient to generate "Vegetables" as indicated - you will get only "Veget". (incidentally - some mention of how to select this value might be useful as it seems to change in each example) 2) inputs = processor(prompts, return_tensors="pt").to(device) as inputs need to be on the same device (as they are in all other examples on the page) * Update idefics.md Change device to cuda explicitly to match other examples --- docs/source/en/tasks/idefics.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/tasks/idefics.md b/docs/source/en/tasks/idefics.md index 0e81efca12..376ec8b308 100644 --- a/docs/source/en/tasks/idefics.md +++ b/docs/source/en/tasks/idefics.md @@ -276,7 +276,7 @@ We can instruct the model to classify the image into one of the categories that >>> inputs = processor(prompt, return_tensors="pt").to("cuda") >>> bad_words_ids = processor.tokenizer(["", ""], add_special_tokens=False).input_ids ->>> generated_ids = model.generate(**inputs, max_new_tokens=4, bad_words_ids=bad_words_ids) +>>> generated_ids = model.generate(**inputs, max_new_tokens=6, bad_words_ids=bad_words_ids) >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) >>> print(generated_text[0]) Instruction: Classify the following image into a single category from the following list: ['animals', 'vegetables', 'city landscape', 'cars', 'office']. @@ -357,7 +357,7 @@ for a batch of examples by passing a list of prompts: ... ], ... ] ->>> inputs = processor(prompts, return_tensors="pt") +>>> inputs = processor(prompts, return_tensors="pt").to("cuda") >>> bad_words_ids = processor.tokenizer(["", ""], add_special_tokens=False).input_ids >>> generated_ids = model.generate(**inputs, max_new_tokens=10, bad_words_ids=bad_words_ids)