Chat template: update for processor (#35953)

* update

* we need batched nested input to always process correctly

* update a bit

* fix copies
This commit is contained in:
Raushan Turganbay
2025-02-10 09:52:19 +01:00
committed by GitHub
parent 5bd7694781
commit eebd2c972c
21 changed files with 966 additions and 111 deletions

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import tempfile
import unittest
@@ -52,6 +53,20 @@ class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_processor_dict(self):
return {"chat_template": "{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"} # fmt: skip
def test_chat_template_is_saved(self):
processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
processor_dict_loaded = json.loads(processor_loaded.to_json_string())
# chat templates aren't serialized to json in processors
self.assertFalse("chat_template" in processor_dict_loaded.keys())
# they have to be saved as separate file and loaded back from that file
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [