Split transformers chat and transformers serve (#38443)
* Next token * Split chat and serve * Support both generation methods * Style * Generation Config * temp * temp * Finalize serving.py Co-authored-by: célina <hanouticelina@gmail.com> * Finalize chat.py * Update src/transformers/commands/serving.py Co-authored-by: célina <hanouticelina@gmail.com> * Lucain's comments Co-authored-by: Lucain <lucain@huggingface.co> * Update * Last comments on PR * Better error handling * Better error handling * CI errors * CI errors * Add tests * Fix tests * Fix tests * [chat] Split chat/serve (built on top of lysandre's PR) (#39031) * Next token * Split chat and serve * Support both generation methods * Style * Generation Config * temp * temp * Finalize serving.py Co-authored-by: célina <hanouticelina@gmail.com> * Finalize chat.py * Update src/transformers/commands/serving.py Co-authored-by: célina <hanouticelina@gmail.com> * Lucain's comments Co-authored-by: Lucain <lucain@huggingface.co> * Update * Last comments on PR * Better error handling * Better error handling * CI errors * CI errors * Add tests * Fix tests * Fix tests * streaming tool call * abstract tool state; set tool start as eos * todos * server working on models without tools * rm chat's deprecated flags * chat defaults * kv cache persists across calls * add server docs * link * Update src/transformers/commands/serving.py * Apply suggestions from code review * i love merge conflicts * solve multi turn with tiny-agents * On the fly switching of the models * Remove required positional arg --------- Co-authored-by: Lysandre <hi@lysand.re> Co-authored-by: célina <hanouticelina@gmail.com> Co-authored-by: Lucain <lucain@huggingface.co> * Protect names * Fix tests --------- Co-authored-by: célina <hanouticelina@gmail.com> Co-authored-by: Lucain <lucain@huggingface.co> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in:
65
tests/commands/test_chat.py
Normal file
65
tests/commands/test_chat.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import transformers.commands.transformers_cli as cli
|
||||
from transformers.commands.chat import ChatArguments, ChatCommand
|
||||
from transformers.testing_utils import CaptureStd
|
||||
|
||||
|
||||
class ChatCLITest(unittest.TestCase):
    """Command-line level checks for the ``transformers chat`` subcommand."""

    def test_help(self):
        """``chat --help`` must raise SystemExit and mention the chat interface on stdout."""
        help_argv = ["transformers", "chat", "--help"]
        with patch("sys.argv", help_argv), CaptureStd() as cs:
            with self.assertRaises(SystemExit):
                cli.main()
        # The captured output is inspected after the capture block has closed.
        captured_stdout = cs.out.lower()
        self.assertIn("chat interface", captured_stdout)

    @patch.object(ChatCommand, "run")
    def test_cli_dispatch(self, run_mock):
        """Invoking ``chat <model>`` must call ``ChatCommand.run`` exactly once."""
        # `run` is mocked, so only the dispatch to ChatCommand.run is verified here.
        dispatch_argv = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
        with patch("sys.argv", dispatch_argv):
            cli.main()
        run_mock.assert_called_once()

    def test_parsed_args(self):
        """The positional model and trailing generate flag must reach ``ChatCommand.__init__``."""
        cli_argv = [
            "transformers",
            "chat",
            "test-model",
            "max_new_tokens=64",
        ]
        with (
            patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
            patch.object(ChatCommand, "run") as run_mock,
            patch("sys.argv", cli_argv),
        ):
            cli.main()

        init_mock.assert_called_once()
        run_mock.assert_called_once()

        # First positional argument handed to the mocked constructor is the parsed namespace.
        namespace = init_mock.call_args.args[0]
        self.assertEqual(namespace.model_name_or_path_or_address, "test-model")
        self.assertEqual(namespace.generate_flags, ["max_new_tokens=64"])
|
||||
|
||||
|
||||
class ChatUtilitiesTest(unittest.TestCase):
    """Checks for ``ChatCommand`` helpers that are called directly, without going through ``cli.main``."""

    def test_save_and_clear_chat(self):
        """``save_chat`` must return the path of an existing file; ``clear_chat_history`` must return ``[]``."""
        # TemporaryDirectory removes the folder and its contents on exit, so the test
        # no longer leaves a stray directory behind as the bare mkdtemp() call did.
        with tempfile.TemporaryDirectory() as tmp_path:
            args = ChatArguments(save_folder=str(tmp_path))
            args.model_name_or_path_or_address = "test-model"

            chat_history = [{"role": "user", "content": "hi"}]
            filename = ChatCommand.save_chat(chat_history, args)
            # Must be asserted while the temporary directory still exists.
            self.assertTrue(os.path.isfile(filename))

        cleared = ChatCommand.clear_chat_history()
        self.assertEqual(cleared, [])

    def test_parse_generate_flags(self):
        """``key=value`` flags must be parsed into a mapping with numeric values (float and int)."""
        # __new__ bypasses ChatCommand.__init__, so no constructor logic runs.
        dummy = ChatCommand.__new__(ChatCommand)
        parsed = ChatCommand.parse_generate_flags(dummy, ["temperature=0.5", "max_new_tokens=10"])
        self.assertEqual(parsed["temperature"], 0.5)
        self.assertEqual(parsed["max_new_tokens"], 10)
|
||||
34
tests/commands/test_serving.py
Normal file
34
tests/commands/test_serving.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import transformers.commands.transformers_cli as cli
|
||||
from transformers.commands.serving import ServeCommand
|
||||
from transformers.testing_utils import CaptureStd
|
||||
|
||||
|
||||
class ServeCLITest(unittest.TestCase):
    """Command-line level checks for the ``transformers serve`` subcommand."""

    def test_help(self):
        """``serve --help`` must raise SystemExit and mention "serve" on stdout."""
        help_argv = ["transformers", "serve", "--help"]
        with patch("sys.argv", help_argv), CaptureStd() as cs:
            with self.assertRaises(SystemExit):
                cli.main()
        # The captured output is inspected after the capture block has closed.
        captured_stdout = cs.out.lower()
        self.assertIn("serve", captured_stdout)

    def test_parsed_args(self):
        """``--host`` and ``--port`` must reach ``ServeCommand.__init__`` (port compared as an int)."""
        cli_argv = ["transformers", "serve", "--host", "0.0.0.0", "--port", "9000"]
        with (
            patch.object(ServeCommand, "__init__", return_value=None) as init_mock,
            patch.object(ServeCommand, "run") as run_mock,
            patch("sys.argv", cli_argv),
        ):
            cli.main()

        init_mock.assert_called_once()
        run_mock.assert_called_once()

        # First positional argument handed to the mocked constructor is the parsed namespace.
        namespace = init_mock.call_args.args[0]
        self.assertEqual(namespace.host, "0.0.0.0")
        self.assertEqual(namespace.port, 9000)

    def test_build_chunk(self):
        """``build_chunk`` output must contain the chunk object name and a ``data:`` marker."""
        # __new__ bypasses ServeCommand.__init__; `args` is set to an empty, attribute-less object.
        dummy = ServeCommand.__new__(ServeCommand)
        empty_args_cls = type("Args", (), {})
        dummy.args = empty_args_cls()

        chunk = ServeCommand.build_chunk(dummy, "hello", "req0", finish_reason="stop")

        for expected_fragment in ("chat.completion.chunk", "data:"):
            self.assertIn(expected_fragment, chunk)
|
||||
Reference in New Issue
Block a user