# Source code for pyrit.memory.memory_embedding
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from typing import Optional
from pyrit.embedding import AzureTextEmbedding
from pyrit.memory.memory_models import EmbeddingDataEntry
from pyrit.models import EmbeddingSupport, MessagePiece
# [docs]  (documentation link marker left over from the HTML source listing)
class MemoryEmbedding:
    """
    The MemoryEmbedding class is responsible for encoding the memory embeddings.

    Parameters:
        embedding_model (EmbeddingSupport): An instance of a class that supports embedding generation.
    """

    def __init__(self, *, embedding_model: Optional[EmbeddingSupport] = None):
        """
        Stores the embedding model used to encode message pieces.

        Args:
            embedding_model (EmbeddingSupport): The model used to generate text embeddings.
                Keyword-only. The default is None, but None is rejected.

        Raises:
            ValueError: If embedding_model is None.
        """
        if embedding_model is None:
            raise ValueError("embedding_model must be set.")
        self.embedding_model = embedding_model

    def generate_embedding_memory_data(self, *, message_piece: MessagePiece) -> EmbeddingDataEntry:
        """
        Generates the embedding entry for a message piece.

        Args:
            message_piece (MessagePiece): the message piece for which to generate a text embedding

        Returns:
            EmbeddingDataEntry: An entry holding the embedding of the piece's converted value,
                the class name of the embedding model, and the id of the message piece.

        Raises:
            ValueError: If the converted value data type of the message piece is not "text".
        """
        # Reject unsupported data types up front so the embedding model is never called for them.
        if message_piece.converted_value_data_type != "text":
            raise ValueError("Only text data is supported for embedding.")

        response = self.embedding_model.generate_text_embedding(text=message_piece.converted_value)
        return EmbeddingDataEntry(
            # Only the first item of the response data is used.
            embedding=response.data[0].embedding,
            embedding_type_name=self.embedding_model.__class__.__name__,
            id=message_piece.id,
        )
 
def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None:
    """
    Builds a MemoryEmbedding from the given model or from Azure OpenAI environment settings.

    When no (truthy) embedding_model is passed, the environment variables
    AZURE_OPENAI_EMBEDDING_KEY, AZURE_OPENAI_EMBEDDING_ENDPOINT and
    AZURE_OPENAI_EMBEDDING_DEPLOYMENT are read and, if all three are non-empty,
    used to construct an AzureTextEmbedding.

    Args:
        embedding_model (EmbeddingSupport): Optional model to wrap directly.

    Returns:
        MemoryEmbedding: The wrapper around the resolved embedding model. Every path
            visible here either returns a MemoryEmbedding or raises; None is not returned.

    Raises:
        ValueError: If no model is provided and any of the three environment variables
            is missing or empty.
    """
    if not embedding_model:
        environment = os.environ
        key = environment.get("AZURE_OPENAI_EMBEDDING_KEY")
        endpoint = environment.get("AZURE_OPENAI_EMBEDDING_ENDPOINT")
        deployment_name = environment.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

        # All three settings are required to fall back to the Azure OpenAI embedding model.
        if not (key and endpoint and deployment_name):
            raise ValueError(
                "No embedding model was provided and no Azure OpenAI embedding model was found in the environment."
            )
        embedding_model = AzureTextEmbedding(api_key=key, endpoint=endpoint, deployment=deployment_name)

    return MemoryEmbedding(embedding_model=embedding_model)