Source code for pyrit.models.embeddings
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from abc import ABC, abstractmethod
from hashlib import sha256
from pathlib import Path
from pydantic import BaseModel, ConfigDict
[docs]
class EmbeddingData(BaseModel):
model_config = ConfigDict(extra="forbid")
embedding: list[float]
index: int
object: str
[docs]
class EmbeddingResponse(BaseModel):
model_config = ConfigDict(extra="forbid")
model: str
object: str
usage: EmbeddingUsageInformation
data: list[EmbeddingData]
[docs]
def save_to_file(self, directory_path: Path) -> str:
"""Save the embedding response to disk and return the path of the new file
Args:
directory_path: The path to save the file to
Returns:
The full path to the file that was saved
"""
embedding_json = self.json()
embedding_hash = sha256(embedding_json.encode()).hexdigest()
embedding_output_file_path = Path(directory_path, f"{embedding_hash}.json")
embedding_output_file_path.write_text(embedding_json)
return embedding_output_file_path.as_posix()
[docs]
@staticmethod
def load_from_file(file_path: Path) -> EmbeddingResponse:
"""Load the embedding response from disk
Args:
file_path: The path to load the file from
Returns:
The loaded embedding response
"""
embedding_json_data = file_path.read_text(encoding="utf-8")
return EmbeddingResponse.model_validate_json(embedding_json_data)
[docs]
def to_json(self) -> str:
return self.model_dump_json()
[docs]
class EmbeddingSupport(ABC):
[docs]
@abstractmethod
def generate_text_embedding(self, text: str, **kwargs) -> EmbeddingResponse:
"""Generate text embedding
Args:
text: The text to generate the embedding for
**kwargs: Additional arguments to pass to the function.
Returns:
The embedding response
"""
raise NotImplementedError("generate_text_embedding method not implemented")