Source code for pyrit.chat_message_normalizer.chat_message_normalizer_chatml

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import re
from pyrit.models import ChatMessage, ChatMessageRole, ALLOWED_CHAT_MESSAGE_ROLES
from pyrit.chat_message_normalizer import ChatMessageNormalizer
from typing import cast


class ChatMessageNormalizerChatML(ChatMessageNormalizer[str]):
    """Convert between lists of chat messages and their ChatML string form."""

    # One ChatML message: everything between an <|im_start|> / <|im_end|> pair.
    # DOTALL lets the message body span several lines.
    _MESSAGE_PATTERN = re.compile(r"<\|im_start\|>(.*?)<\|im_end\|>", re.DOTALL | re.MULTILINE)
    # Header line of a message: "<role>" optionally followed by " name=<name>".
    _HEADER_PATTERN = re.compile(r"(?P<role>\w+)( name=(?P<name>\w+))?")

    def normalize(self, messages: list[ChatMessage]) -> str:
        """Render a list of chat messages as a single ChatML string.

        This is compliant with the ChatML specified in
        https://github.com/openai/openai-python/blob/release-v0.28.0/chatml.md

        Args:
            messages: The chat messages to serialize, in order.

        Returns:
            The concatenated ChatML blocks, one per message, each terminated
            by a newline. An empty list yields an empty string.
        """
        return "".join(
            f"<|im_start|>{m.role}{f' name={m.name}' if m.name else ''}\n{m.content}<|im_end|>\n"
            for m in messages
        )

    @staticmethod
    def from_chatml(content: str) -> list[ChatMessage]:
        """Convert a ChatML string to a list of chat messages.

        Args:
            content: Text containing one or more
                ``<|im_start|>...<|im_end|>`` blocks.

        Returns:
            One chat message per block, in the order the blocks appear.

        Raises:
            ValueError: If no block is found, if a block's header line does
                not start with a role, or if the role is not one of
                ``ALLOWED_CHAT_MESSAGE_ROLES``.
        """
        matches = list(ChatMessageNormalizerChatML._MESSAGE_PATTERN.finditer(content))
        if not matches:
            raise ValueError("No chat messages found in the chatML string")

        messages: list[ChatMessage] = []
        for match in matches:
            # The first line is the header; everything after the first newline
            # is the message body. Keeping the whole remainder (rather than
            # only the next line) preserves multi-line message content.
            header, _, body = match.group(1).partition("\n")
            header = header.strip()

            header_match = ChatMessageNormalizerChatML._HEADER_PATTERN.match(header)
            if header_match is None:
                # Previously this fell through to an AttributeError on None.
                raise ValueError(f"Invalid chatML message header: {header!r}")

            role = header_match.group("role")
            if role not in ALLOWED_CHAT_MESSAGE_ROLES:
                raise ValueError(f"Role {role} is not allowed in chatML")

            messages.append(
                ChatMessage(
                    role=cast(ChatMessageRole, role),
                    content=body.strip(),
                    name=header_match.group("name"),
                )
            )
        return messages