Source code for pyrit.chat_message_normalizer.chat_message_normalizer_tokenizer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from pyrit.chat_message_normalizer import ChatMessageNormalizer
from pyrit.models import ChatMessage
[docs]
class ChatMessageNormalizerTokenizerTemplate(ChatMessageNormalizer[str]):
    """
    This class enables you to apply the chat template stored in a Hugging Face tokenizer
    to a list of chat messages. For more details, see
    https://huggingface.co/docs/transformers/main/en/chat_templating
    """

    def __init__(self, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast):
        """
        Initializes an instance of the ChatMessageNormalizerTokenizerTemplate class.

        Args:
            tokenizer (PreTrainedTokenizer | PreTrainedTokenizerFast): A Hugging Face tokenizer
                whose ``apply_chat_template`` method is used by ``normalize``.
        """
        self.tokenizer = tokenizer

    def normalize(self, messages: list[ChatMessage]) -> str:
        """
        Applies the chat template stored in the tokenizer to a list of chat messages.

        Args:
            messages (list[ChatMessage]): A list of ChatMessage objects. Only each
                message's ``role`` and ``content`` are passed to the template.

        Returns:
            str: The formatted chat messages, rendered with ``add_generation_prompt=True``.
        """
        # The tokenizer's chat template consumes plain {"role", "content"} dicts,
        # not ChatMessage objects, so convert each message first.
        messages_list = [{"role": m.role, "content": m.content} for m in messages]
        # tokenize=False asks for the rendered prompt text rather than token IDs.
        # add_generation_prompt=True, per the Hugging Face chat-templating docs,
        # appends the template's cue for the next assistant turn (if it defines one).
        return self.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True,
        )