Source code for pyrit.models.embeddings

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from abc import ABC, abstractmethod
from hashlib import sha256
from pathlib import Path

from pydantic import BaseModel, ConfigDict


class EmbeddingUsageInformation(BaseModel):
    """Token usage figures that accompany an embedding API response."""

    # extra="forbid": validation fails on any field not declared below.
    model_config = ConfigDict(extra="forbid")

    # Declaration order is kept as-is; it determines serialized field order.
    prompt_tokens: int
    total_tokens: int
class EmbeddingData(BaseModel):
    """One embedding vector plus the index and object metadata sent with it."""

    # extra="forbid": validation fails on any field not declared below.
    model_config = ConfigDict(extra="forbid")

    # Declaration order is kept as-is; it determines serialized field order.
    embedding: list[float]
    index: int
    object: str
class EmbeddingResponse(BaseModel):
    """Embedding API response containing vectors, model metadata, and usage."""

    # extra="forbid": validation fails on any field not declared below.
    model_config = ConfigDict(extra="forbid")

    model: str
    object: str
    usage: EmbeddingUsageInformation
    data: list[EmbeddingData]

    def save_to_file(self, directory_path: Path) -> str:
        """
        Save the embedding response to disk and return the path of the new file.

        The file is named ``<sha256 of the serialized JSON>.json`` and is always
        written as UTF-8, which is the encoding ``load_from_file`` reads and the
        encoding the hash is computed over.

        Args:
            directory_path (Path): The directory to write the file into.

        Returns:
            str: The full path to the file that was saved, in POSIX form.

        Raises:
            OSError: If the file cannot be written (for example, the directory
                does not exist).
        """
        embedding_json = self.model_dump_json()
        embedding_hash = sha256(embedding_json.encode("utf-8")).hexdigest()
        embedding_output_file_path = Path(directory_path, f"{embedding_hash}.json")
        # Pin the encoding: the platform default is locale-dependent and may not
        # be UTF-8, which would make the file unreadable by load_from_file.
        embedding_output_file_path.write_text(embedding_json, encoding="utf-8")
        return embedding_output_file_path.as_posix()

    @staticmethod
    def load_from_file(file_path: Path) -> EmbeddingResponse:
        """
        Load the embedding response from disk.

        Args:
            file_path (Path): The path of the UTF-8 encoded JSON file to load.

        Returns:
            EmbeddingResponse: The loaded embedding response.

        Raises:
            OSError: If the file cannot be read.
            pydantic.ValidationError: If the file content does not validate as
                an ``EmbeddingResponse``.
        """
        embedding_json_data = file_path.read_text(encoding="utf-8")
        return EmbeddingResponse.model_validate_json(embedding_json_data)

    def to_json(self) -> str:
        """
        Serialize this embedding response to JSON.

        Returns:
            str: JSON-encoded embedding response.
        """
        return self.model_dump_json()
class EmbeddingSupport(ABC):
    """Abstract base class for components that produce text embeddings."""

    @abstractmethod
    def generate_text_embedding(self, text: str, **kwargs: object) -> EmbeddingResponse:
        """
        Produce an embedding for ``text`` using a blocking call.

        Args:
            text: The input text to embed.
            **kwargs: Extra keyword arguments accepted by the implementation.

        Returns:
            The embedding response for ``text``.

        Raises:
            NotImplementedError: When this base implementation is invoked
                directly; subclasses must override the method.
        """
        raise NotImplementedError("generate_text_embedding method not implemented")

    @abstractmethod
    async def generate_text_embedding_async(self, text: str, **kwargs: object) -> EmbeddingResponse:
        """
        Produce an embedding for ``text`` as a coroutine.

        Args:
            text: The input text to embed.
            **kwargs: Extra keyword arguments accepted by the implementation.

        Returns:
            The embedding response for ``text``.

        Raises:
            NotImplementedError: When this base implementation is awaited
                directly; subclasses must override the method.
        """
        raise NotImplementedError("generate_text_embedding_async method not implemented")