Transformers serve VLM (#39454)
* Add support for VLMs in Transformers Serve * Raushan comments * Update src/transformers/commands/serving.py Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> * Quick fix * CPU -> Auto * Update src/transformers/commands/serving.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Fixup --------- Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
import unittest
|
||||
from threading import Thread
|
||||
@@ -23,7 +24,7 @@ from parameterized import parameterized
|
||||
|
||||
import transformers.commands.transformers_cli as cli
|
||||
from transformers import GenerationConfig
|
||||
from transformers.commands.serving import ServeArguments, ServeCommand
|
||||
from transformers.commands.serving import Modality, ServeArguments, ServeCommand
|
||||
from transformers.testing_utils import CaptureStd, require_openai, slow
|
||||
from transformers.utils.import_utils import is_openai_available
|
||||
|
||||
@@ -258,6 +259,104 @@ class ServeCompletionsMixin:
|
||||
# TODO: speed-based test to confirm that KV cache is working across requests
|
||||
|
||||
|
||||
class ServeCompletionsGenerateMockTests(unittest.TestCase):
|
||||
def test_processor_inputs_from_inbound_messages_llm(self):
|
||||
modality = Modality.LLM
|
||||
messages = expected_outputs = [
|
||||
{"role": "user", "content": "How are you doing?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
|
||||
{"role": "user", "content": "Can you help me write tests?"},
|
||||
]
|
||||
outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
|
||||
self.assertListEqual(expected_outputs, outputs)
|
||||
|
||||
def test_processor_inputs_from_inbound_messages_vlm_text_only(self):
|
||||
modality = Modality.VLM
|
||||
messages = [
|
||||
{"role": "user", "content": "How are you doing?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you for asking! How can I assist you today?"},
|
||||
{"role": "user", "content": "Can you help me write tests?"},
|
||||
]
|
||||
|
||||
expected_outputs = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "How are you doing?"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I'm doing great, thank you for asking! How can I assist you today?"}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": "Can you help me write tests?"}]},
|
||||
]
|
||||
|
||||
outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
|
||||
self.assertListEqual(expected_outputs, outputs)
|
||||
|
||||
def test_processor_inputs_from_inbound_messages_vlm_text_and_image_in_base_64(self):
|
||||
modality = Modality.VLM
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "How many pixels are in the image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAASABIAAD/4QBARXhpZgAATU0AKgAAAAgAAYdpAAQAAAABAAAAGgAAAAAAAqACAAQAAAABAAAABaADAAQAAAABAAAABQAAAAD/7QA4UGhvdG9zaG9wIDMuMAA4QklNBAQAAAAAAAA4QklNBCUAAAAAABDUHYzZjwCyBOmACZjs+EJ+/8AAEQgABQAFAwEiAAIRAQMRAf/EAB8AAAEFAQEBAQEBAAAAAAAAAAABAgMEBQYHCAkKC//EALUQAAIBAwMCBAMFBQQEAAABfQECAwAEEQUSITFBBhNRYQcicRQygZGhCCNCscEVUtHwJDNicoIJChYXGBkaJSYnKCkqNDU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6g4SFhoeIiYqSk5SVlpeYmZqio6Slpqeoqaqys7S1tre4ubrCw8TFxsfIycrS09TV1tfY2drh4uPk5ebn6Onq8fLz9PX29/j5+v/EAB8BAAMBAQEBAQEBAQEAAAAAAAABAgMEBQYHCAkKC//EALURAAIBAgQEAwQHBQQEAAECdwABAgMRBAUhMQYSQVEHYXETIjKBCBRCkaGxwQkjM1LwFWJy0QoWJDThJfEXGBkaJicoKSo1Njc4OTpDREVGR0hJSlNUVVZXWFlaY2RlZmdoaWpzdHV2d3h5eoKDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uLj5OXm5+jp6vLz9PX29/j5+v/bAEMAAQEBAQEBAgEBAgICAgICAwICAgIDBAMDAwMDBAUEBAQEBAQFBQUFBQUFBQYGBgYGBgcHBwcHCAgICAgICAgICP/bAEMBAQEBAgICAwICAwgFBAUICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICP/dAAQAAf/aAAwDAQACEQMRAD8A/v4ooooA/9k="
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The number of pixels in the image cannot be determined from the provided information.",
|
||||
},
|
||||
{"role": "user", "content": "Alright"},
|
||||
]
|
||||
|
||||
expected_outputs = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "How many pixels are in the image?"},
|
||||
{"type": "image", "url": "/var/folders/4v/64sxdhsd3gz3r8vhhnyc0mqw0000gn/T/tmp50oyghk6.png"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The number of pixels in the image cannot be determined from the provided information.",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": [{"type": "text", "text": "Alright"}]},
|
||||
]
|
||||
|
||||
outputs = ServeCommand.get_processor_inputs_from_inbound_messages(messages, modality)
|
||||
|
||||
for expected_output, output in zip(expected_outputs, outputs):
|
||||
expected_output_content = expected_output["content"]
|
||||
output_content = output["content"]
|
||||
|
||||
self.assertEqual(type(expected_output_content), type(output_content))
|
||||
|
||||
if isinstance(expected_output_content, list):
|
||||
for expected_output_content_item, output_content_item in zip(expected_output_content, output_content):
|
||||
self.assertIn("type", expected_output_content_item)
|
||||
self.assertIn("type", output_content_item)
|
||||
self.assertTrue(expected_output_content_item["type"] == output_content_item["type"])
|
||||
|
||||
if expected_output_content_item["type"] == "text":
|
||||
self.assertEqual(expected_output_content_item["text"], output_content_item["text"])
|
||||
|
||||
if expected_output_content_item["type"] == "image":
|
||||
self.assertTrue(os.path.exists(output_content_item["url"]))
|
||||
else:
|
||||
raise ValueError("VLMs should only receive content as lists.")
|
||||
|
||||
|
||||
@slow # server startup time is slow on our push CI
|
||||
@require_openai
|
||||
class ServeCompletionsGenerateIntegrationTest(ServeCompletionsMixin, unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user