# Source code for pyrit.memory.memory_embedding
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional
from pyrit.embedding import OpenAITextEmbedding
from pyrit.memory.memory_models import EmbeddingDataEntry
from pyrit.models import EmbeddingSupport, MessagePiece
class MemoryEmbedding:
    """
    The MemoryEmbedding class is responsible for encoding the memory embeddings.

    It wraps an embedding model and converts text message pieces into
    ``EmbeddingDataEntry`` records.

    Parameters:
        embedding_model (EmbeddingSupport): An instance of a class that supports embedding generation.
    """

    def __init__(self, *, embedding_model: Optional[EmbeddingSupport] = None):
        """
        Initialize the memory embedding helper with a backing embedding model.

        Args:
            embedding_model (Optional[EmbeddingSupport]): The embedding model used to
                generate text embeddings. Although the parameter defaults to None,
                a model is required; omitting it raises a ValueError.

        Raises:
            ValueError: If `embedding_model` is not provided.
        """
        if embedding_model is None:
            raise ValueError("embedding_model must be set.")
        self.embedding_model = embedding_model

    def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> EmbeddingDataEntry:
        """
        Generate the embedding entry for a text message piece.

        The piece's converted value is passed to the embedding model, and the
        first embedding in the model's response is stored in the entry.

        Args:
            message_piece (MessagePiece): The message piece for which to generate a text embedding.

        Returns:
            EmbeddingDataEntry: An entry holding the embedding vector, the class name of
                the embedding model that produced it, and the id of the message piece.

        Raises:
            ValueError: If the message piece is not of type text.
        """
        # Reject non-text pieces up front so the embedding model is never called for them.
        if message_piece.converted_value_data_type != "text":
            raise ValueError("Only text data is supported for embedding.")

        embedding_response = self.embedding_model.generate_text_embedding(text=message_piece.converted_value)
        return EmbeddingDataEntry(
            # Only the first item of the response data is used.
            embedding=embedding_response.data[0].embedding,
            embedding_type_name=self.embedding_model.__class__.__name__,
            id=message_piece.id,
        )
def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None:
    """
    Create a MemoryEmbedding instance with default or provided embedding model.

    Factory function that creates a MemoryEmbedding instance. If an embedding_model
    is provided, it uses that model. Otherwise, it attempts to create an OpenAI
    embedding model from environment variables.

    Args:
        embedding_model: Optional embedding model to use. If not provided,
            attempts to create OpenAITextEmbedding from environment variables.

    Returns:
        MemoryEmbedding: Configured memory embedding instance. Every path in this
            function either returns an instance or raises; None is never returned.

    Raises:
        ValueError: If no embedding model is provided and required
            OpenAI environment variables are not set.
    """
    # Compare against None explicitly: a caller-supplied model must be honoured even
    # if the object happens to evaluate as falsy (e.g. it defines __bool__ or __len__).
    if embedding_model is not None:
        return MemoryEmbedding(embedding_model=embedding_model)

    # Try to create OpenAI embedding model from environment variables
    # The constructor will check for OPENAI_EMBEDDING_KEY, OPENAI_EMBEDDING_ENDPOINT, and OPENAI_EMBEDDING_MODEL
    try:
        model = OpenAITextEmbedding()
    except ValueError as err:
        # Chain the original error so the underlying configuration problem stays visible.
        raise ValueError(
            "No embedding model was provided and no OpenAI embedding model was found in the environment."
        ) from err
    return MemoryEmbedding(embedding_model=model)