Fix Mistral3 tests (#36797)
* fix processor tests
* fix modeling tests
* fix test processor chat template
* revert modeling test changes
This commit is contained in:
@@ -54,7 +54,7 @@ Here is how you can use the `image-text-to-text` pipeline to perform inference w
|
|||||||
... },
|
... },
|
||||||
... ]
|
... ]
|
||||||
|
|
||||||
>>> pipe = pipeline("image-text-to-text", model="../mistral3_weights", torch_dtype=torch.bfloat16)
|
>>> pipe = pipeline("image-text-to-text", model="mistralai/Mistral-Small-3.1-24B-Instruct-2503", torch_dtype=torch.bfloat16)
|
||||||
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
|
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
|
||||||
>>> outputs[0]["generated_text"]
|
>>> outputs[0]["generated_text"]
|
||||||
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
|
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
|
||||||
|
|||||||
@@ -51,17 +51,75 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
processor = PixtralProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
processor = self.processor_class.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503")
|
||||||
processor.save_pretrained(self.tmpdirname)
|
processor.save_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
|
def get_processor(self):
|
||||||
|
return self.processor_class.from_pretrained(self.tmpdirname)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
shutil.rmtree(self.tmpdirname)
|
shutil.rmtree(self.tmpdirname)
|
||||||
|
|
||||||
def test_chat_template(self):
|
def test_chat_template_accepts_processing_kwargs(self):
|
||||||
processor = self.processor_class.from_pretrained(self.tmpdirname)
|
# override to use slow image processor to return numpy arrays
|
||||||
expected_prompt = "<s>[INST][IMG]What is shown in this image?[/INST]"
|
processor = self.processor_class.from_pretrained(self.tmpdirname, use_fast=False)
|
||||||
|
if processor.chat_template is None:
|
||||||
|
self.skipTest("Processor has no chat template")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=50,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(formatted_prompt_tokenized[0]), 50)
|
||||||
|
|
||||||
|
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=5,
|
||||||
|
)
|
||||||
|
self.assertEqual(len(formatted_prompt_tokenized[0]), 5)
|
||||||
|
|
||||||
|
# Now test the ability to return dict
|
||||||
|
messages[0][0]["content"].append(
|
||||||
|
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}
|
||||||
|
)
|
||||||
|
out_dict = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
do_rescale=True,
|
||||||
|
rescale_factor=-1,
|
||||||
|
return_tensors="np",
|
||||||
|
)
|
||||||
|
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
|
||||||
|
|
||||||
|
def test_chat_template(self):
|
||||||
|
processor = self.processor_class.from_pretrained(self.tmpdirname, use_fast=False)
|
||||||
|
expected_prompt = "<s>[SYSTEM_PROMPT][/SYSTEM_PROMPT][INST][IMG]What is shown in this image?[/INST]"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
@@ -81,6 +139,10 @@ class Mistral3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
image_token_index = 10
|
image_token_index = 10
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
|
|||||||
Reference in New Issue
Block a user