# Source code for pyrit.prompt_converter.charswap_attack_converter

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import math
import random
import re
import string

from pyrit.models import PromptDataType
from pyrit.prompt_converter import ConverterResult, PromptConverter

# Module-level logger named after this module, so its output can be filtered per module.
logger = logging.getLogger(__name__)


class CharSwapGenerator(PromptConverter):
    """
    A PromptConverter that applies character swapping to words in the prompt
    to test adversarial textual robustness.
    """

    def __init__(self, *, max_iterations: int = 10, word_swap_ratio: float = 0.2):
        """
        Initializes the CharSwapGenerator.

        Args:
            max_iterations (int): Number of perturbation passes applied to the prompt.
                Each pass picks a fresh random set of words, so a higher number makes it
                more likely that words end up different from the original prompt.
            word_swap_ratio (float): Fraction of the words to perturb per pass
                (0 < word_swap_ratio <= 1).

        Raises:
            ValueError: If max_iterations is not positive, or word_swap_ratio is
                outside the interval (0, 1].
        """
        super().__init__()

        # Ensure max_iterations is positive
        if max_iterations <= 0:
            raise ValueError("max_iterations must be greater than 0")

        # Ensure word_swap_ratio is between 0 and 1
        if not (0 < word_swap_ratio <= 1):
            raise ValueError("word_swap_ratio must be between 0 and 1 (exclusive of 0)")

        self.max_iterations = max_iterations
        self.word_swap_ratio = word_swap_ratio

    def input_supported(self, input_type: PromptDataType) -> bool:
        """Return True only for the "text" input type."""
        return input_type == "text"

    def output_supported(self, output_type: PromptDataType) -> bool:
        """Return True only for the "text" output type."""
        return output_type == "text"

    def _perturb_word(self, word: str) -> str:
        """
        Perturb a word by swapping two adjacent characters.

        The first character is never moved: the swap position is drawn from
        index 1 up to the second-to-last index, and the character there is
        exchanged with its right-hand neighbour (which may be the last one).

        Args:
            word (str): The word to perturb.

        Returns:
            str: The perturbed word, or the word unchanged when it has three or
            fewer characters or consists solely of punctuation.
        """
        # A per-character test is used on purpose: `word in string.punctuation` would be
        # a substring test against the constant and would let tokens like "!!!!" through.
        if len(word) <= 3 or all(char in string.punctuation for char in word):
            return word

        swap_at = random.randint(1, len(word) - 2)
        chars = list(word)
        chars[swap_at], chars[swap_at + 1] = chars[swap_at + 1], chars[swap_at]
        return "".join(chars)

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Converts the given prompt by applying character swaps.

        The prompt is tokenized into word and non-whitespace runs; for each of
        ``max_iterations`` passes a random subset of tokens (sized by
        ``word_swap_ratio``, at least one) is perturbed. Tokens are then
        re-joined with single spaces, so original whitespace is not preserved.

        Args:
            prompt (str): The prompt to be converted.
            input_type (PromptDataType): Type of the input; only "text" is supported.

        Returns:
            ConverterResult: The result containing the perturbed prompt.

        Raises:
            ValueError: If the input type is not supported.
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        # Tokenize the prompt into words and punctuation using regex
        words = re.findall(r"\w+|\S+", prompt)

        # Nothing to perturb (empty or whitespace-only prompt)
        if not words:
            return ConverterResult(output_text="", output_type="text")

        word_count = len(words)
        num_perturb_words = max(1, math.ceil(word_count * self.word_swap_ratio))

        # Each pass draws a new random set of word indices and swaps characters in them
        for _ in range(self.max_iterations):
            for idx in self._get_n_random(0, word_count, num_perturb_words):
                words[idx] = self._perturb_word(words[idx])

        # Join the perturbed words back into a prompt
        new_prompt = " ".join(words)

        # Clean up spaces around punctuation
        output_text = re.sub(r'\s([?.!,\'"])', r"\1", new_prompt).strip()

        return ConverterResult(output_text=output_text, output_type="text")

    def _get_n_random(self, low: int, high: int, n: int) -> list:
        """
        Utility function to generate random indices.
        Words at these indices will be subjected to perturbation.

        Args:
            low (int): Inclusive lower bound of the index range.
            high (int): Exclusive upper bound of the index range.
            n (int): Number of distinct indices to draw.

        Returns:
            list: ``n`` distinct indices from ``range(low, high)``, or an empty
            list when ``random.sample`` rejects the request (for example when
            ``n`` exceeds the population size).
        """
        try:
            return random.sample(range(low, high), n)
        except ValueError:
            logger.debug(
                "[CharSwapConverter] Sample size of %s exceeds population size of %s",
                n,
                high - low,
            )
            return []