# Source code for pyrit.memory.memory_embedding
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from typing import Optional
from pyrit.embedding import AzureTextEmbedding
from pyrit.models import PromptRequestPiece, EmbeddingSupport
from pyrit.memory.memory_models import EmbeddingDataEntry
# [docs]
class MemoryEmbedding:
    """
    The MemoryEmbedding class is responsible for encoding the memory embeddings.

    Parameters:
        embedding_model (EmbeddingSupport): An instance of a class that supports embedding generation.
    """

    def __init__(self, *, embedding_model: Optional[EmbeddingSupport]):
        """
        Stores the embedding model used to encode prompt request pieces.

        Args:
            embedding_model (Optional[EmbeddingSupport]): The model used to generate embeddings.
                Typed as Optional so callers may forward an unset value, but None is rejected.

        Raises:
            ValueError: If embedding_model is None.
        """
        if embedding_model is None:
            raise ValueError("embedding_model must be set.")
        self.embedding_model = embedding_model

    def generate_embedding_memory_data(self, *, prompt_request_piece: PromptRequestPiece) -> EmbeddingDataEntry:
        """
        Generates the embedding entry for a single prompt request piece.

        Args:
            prompt_request_piece (PromptRequestPiece): The piece whose converted value is embedded.
                Its converted_value_data_type must be "text".

        Returns:
            EmbeddingDataEntry: An entry holding the embedding vector, the class name of the
                embedding model that produced it, and the id of the prompt request piece.

        Raises:
            ValueError: If the converted value of the piece is not of type "text".
        """
        # Guard clause: only text can be embedded, so reject anything else up front.
        if prompt_request_piece.converted_value_data_type != "text":
            raise ValueError("Only text data is supported for embedding.")

        embedding_response = self.embedding_model.generate_text_embedding(
            text=prompt_request_piece.converted_value
        )
        return EmbeddingDataEntry(
            # A single text was submitted, so the vector is read from the first result.
            embedding=embedding_response.data[0].embedding,
            # Record which model class produced the vector.
            embedding_type_name=self.embedding_model.__class__.__name__,
            id=prompt_request_piece.id,
        )
def default_memory_embedding_factory(embedding_model: Optional[EmbeddingSupport] = None) -> MemoryEmbedding | None:
    """
    Builds a MemoryEmbedding from an explicit model or from Azure OpenAI environment settings.

    Args:
        embedding_model (Optional[EmbeddingSupport]): The embedding model to wrap. When it is
            None, an AzureTextEmbedding is built from the AZURE_OPENAI_EMBEDDING_KEY,
            AZURE_OPENAI_EMBEDDING_ENDPOINT and AZURE_OPENAI_EMBEDDING_DEPLOYMENT
            environment variables.

    Returns:
        MemoryEmbedding: The wrapper around the chosen embedding model. The declared return
            type allows None, but no path in this function returns None; it raises instead.

    Raises:
        ValueError: If no model is passed and any of the three environment variables is
            missing or empty.
    """
    # Identity check rather than truthiness, matching MemoryEmbedding.__init__, so that a
    # supplied model is never skipped merely because it evaluates as falsy.
    if embedding_model is not None:
        return MemoryEmbedding(embedding_model=embedding_model)

    api_key = os.environ.get("AZURE_OPENAI_EMBEDDING_KEY")
    api_base = os.environ.get("AZURE_OPENAI_EMBEDDING_ENDPOINT")
    deployment = os.environ.get("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

    # All three settings are required; an unset or empty value counts as missing.
    if not (api_key and api_base and deployment):
        raise ValueError(
            "No embedding model was provided and no Azure OpenAI embedding model was found in the environment."
        )

    model = AzureTextEmbedding(api_key=api_key, endpoint=api_base, deployment=deployment)
    return MemoryEmbedding(embedding_model=model)