Source code for pyrit.models.data_type_serializer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import abc
import base64
import hashlib
import os
import time
from typing import Optional, TYPE_CHECKING, Union
from pathlib import Path
from mimetypes import guess_type
from urllib.parse import urlparse
from pyrit.models.literals import PromptDataType
if TYPE_CHECKING:
from pyrit.memory import MemoryInterface
[docs]
def data_serializer_factory(
*,
data_type: PromptDataType,
value: Optional[str] = None,
extension: Optional[str] = None,
):
if value:
if data_type == "text":
return TextDataTypeSerializer(prompt_text=value)
elif data_type == "image_path":
return ImagePathDataTypeSerializer(prompt_text=value, extension=extension)
elif data_type == "audio_path":
return AudioPathDataTypeSerializer(prompt_text=value, extension=extension)
elif data_type == "error":
return ErrorDataTypeSerializer(prompt_text=value)
elif data_type == "url":
return URLDataTypeSerializer(prompt_text=value)
else:
raise ValueError(f"Data type {data_type} not supported")
else:
if data_type == "image_path":
return ImagePathDataTypeSerializer(extension=extension)
elif data_type == "audio_path":
return AudioPathDataTypeSerializer(extension=extension)
elif data_type == "error":
return ErrorDataTypeSerializer(prompt_text="")
else:
raise ValueError(f"Data type {data_type} without prompt text not supported")
[docs]
class DataTypeSerializer(abc.ABC):
"""
Abstract base class for data type normalizers.
This class is responsible for saving multi-modal types to disk.
"""
data_type: PromptDataType
value: str
data_sub_directory: str
file_extension: str
_file_path: Union[Path, str] = None
@property
def _memory(self) -> MemoryInterface:
from pyrit.memory import CentralMemory
return CentralMemory.get_memory_instance()
[docs]
@abc.abstractmethod
def data_on_disk(self) -> bool:
"""
Returns True if the data is stored on disk.
"""
pass
[docs]
async def save_data(self, data: bytes) -> None:
"""
Saves the data to storage.
"""
file_path = await self.get_data_filename()
await self._memory.storage_io.write_file(file_path, data)
self.value = str(file_path)
[docs]
async def save_b64_image(self, data: str, output_filename: str = None) -> None:
"""
Saves the base64 encoded image to storage.
Arguments:
data: string with base64 data
output_filename (optional, str): filename to store image as. Defaults to UUID if not provided
"""
file_path: Union[Path, str] = None
if output_filename:
file_path = output_filename
else:
file_path = await self.get_data_filename()
image_bytes = base64.b64decode(data)
await self._memory.storage_io.write_file(file_path, image_bytes)
self.value = str(file_path)
[docs]
async def read_data(self) -> bytes:
"""
Reads the data from the storage.
"""
if not self.data_on_disk():
raise TypeError(f"Data for data Type {self.data_type} is not stored on disk")
if not self.value:
raise RuntimeError("Prompt text not set")
# Check if path exists
file_exists = await self._memory.storage_io.path_exists(path=self.value)
if not file_exists:
raise FileNotFoundError(f"File not found: {self.value}")
# Read the contents from the path
return await self._memory.storage_io.read_file(self.value)
[docs]
async def read_data_base64(self) -> str:
"""
Reads the data from the storage.
"""
byte_array = await self.read_data()
return base64.b64encode(byte_array).decode("utf-8")
[docs]
async def get_sha256(self) -> str:
input_bytes: bytes = None
if self.data_on_disk():
input_bytes = await self._memory.storage_io.read_file(self.value)
else:
if isinstance(self.value, str):
input_bytes = self.value.encode("utf-8")
else:
raise ValueError(f"Invalid data type {self.value}, expected str data type.")
hash_object = hashlib.sha256(input_bytes)
return hash_object.hexdigest()
[docs]
async def get_data_filename(self) -> Union[Path, str]:
"""
Generates or retrieves a unique filename for the data file.
"""
if self._file_path:
return self._file_path
if not self.data_on_disk():
raise TypeError("Data is not stored on disk")
if not self.data_sub_directory:
raise RuntimeError("Data sub directory not set")
ticks = int(time.time() * 1_000_000)
results_path = self._memory.results_path
if self.is_url(results_path):
full_data_directory_path = results_path + self.data_sub_directory
self._file_path = full_data_directory_path + f"/{ticks}.{self.file_extension}"
else:
full_data_directory_path = results_path + self.data_sub_directory
await self._memory.storage_io.create_directory_if_not_exists(Path(full_data_directory_path))
self._file_path = Path(full_data_directory_path, f"{ticks}.{self.file_extension}")
return self._file_path
[docs]
@staticmethod
def get_extension(file_path: str) -> str | None:
"""
Get the file extension from the file path.
"""
_, ext = os.path.splitext(file_path)
return ext if ext else None
[docs]
@staticmethod
def get_mime_type(file_path: str) -> str | None:
"""
Get the MIME type of the file path.
"""
mime_type, _ = guess_type(file_path)
return mime_type
[docs]
def is_url(self, path: str) -> bool:
"""
Helper function to check if a given path is a URL.
"""
return urlparse(path).scheme in ("http", "https")
[docs]
class TextDataTypeSerializer(DataTypeSerializer):
[docs]
def __init__(self, *, prompt_text: str):
self.data_type = "text"
self.value = prompt_text
[docs]
def data_on_disk(self) -> bool:
return False
[docs]
class ErrorDataTypeSerializer(DataTypeSerializer):
[docs]
def __init__(self, *, prompt_text: str):
self.data_type = "error"
self.value = prompt_text
[docs]
def data_on_disk(self) -> bool:
return False
class URLDataTypeSerializer(DataTypeSerializer):
def __init__(self, *, prompt_text: str):
self.data_type = "url"
self.value = prompt_text
def data_on_disk(self) -> bool:
return False
[docs]
class ImagePathDataTypeSerializer(DataTypeSerializer):
[docs]
def __init__(self, *, prompt_text: Optional[str] = None, extension: Optional[str] = None):
self.data_type = "image_path"
self.data_sub_directory = "/dbdata/images"
self.file_extension = extension if extension else "png"
if prompt_text:
self.value = prompt_text
[docs]
def data_on_disk(self) -> bool:
return True
[docs]
class AudioPathDataTypeSerializer(DataTypeSerializer):
[docs]
def __init__(
self,
*,
prompt_text: Optional[str] = None,
extension: Optional[str] = None,
):
self.data_type = "audio_path"
self.data_sub_directory = "/dbdata/audio"
self.file_extension = extension if extension else "mp3"
if prompt_text:
self.value = prompt_text
[docs]
def data_on_disk(self) -> bool:
return True