Source code for pyrit.prompt_converter.charswap_attack_converter
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import math
import random
import re
import string
from pyrit.models import PromptDataType
from pyrit.prompt_converter import ConverterResult, PromptConverter
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)
[docs]
class CharSwapGenerator(PromptConverter):
    """
    A PromptConverter that applies character swapping to words in the prompt
    to test adversarial textual robustness.

    A fraction (``word_swap_ratio``) of the prompt's tokens is selected at random; in each
    selected token longer than three characters that is not made up entirely of punctuation,
    two adjacent characters are swapped. Everything else in the prompt, including the
    whitespace between tokens, is kept as-is (leading/trailing whitespace is stripped).
    """

    def __init__(self, *, max_iterations: int = 10, word_swap_ratio: float = 0.2):
        """
        Initializes the CharSwapGenerator.

        Args:
            max_iterations (int): Number of times to generate perturbed prompts.
                The higher the number the higher the chance that words are different from the original prompt.
                NOTE(review): this value is validated and stored, but ``convert_async`` never reads it;
                each conversion performs a single perturbation pass. Confirm intended semantics.
            word_swap_ratio (float): Fraction of words to perturb in the prompt per iteration,
                in the range (0, 1].

        Raises:
            ValueError: If ``max_iterations`` is not positive, or ``word_swap_ratio`` is not in (0, 1].
        """
        super().__init__()

        # Ensure max_iterations is positive
        if max_iterations <= 0:
            raise ValueError("max_iterations must be greater than 0")

        # Ensure word_swap_ratio is between 0 and 1 (the chained comparison also rejects NaN)
        if not (0 < word_swap_ratio <= 1):
            raise ValueError("word_swap_ratio must be between 0 and 1 (exclusive of 0)")

        self.max_iterations = max_iterations
        self.word_swap_ratio = word_swap_ratio

    def output_supported(self, output_type: PromptDataType) -> bool:
        """Return True only for ``"text"``, the only output type this converter produces."""
        return output_type == "text"

    def _perturb_word(self, word: str) -> str:
        """
        Perturb a word by swapping two adjacent characters.

        Tokens of three characters or fewer, and tokens consisting solely of punctuation,
        are returned unchanged (and consume no randomness).

        Args:
            word (str): The word to perturb.

        Returns:
            str: The perturbed word with swapped characters, or ``word`` itself when it is skipped.
        """
        # ``word in string.punctuation`` would be a *substring* test against the punctuation
        # alphabet (so "?!?!" would not count as punctuation); check each character instead.
        if len(word) <= 3 or all(ch in string.punctuation for ch in word):
            return word

        # idx1 is drawn from [1, len(word) - 2]: the first character never moves, while the
        # swapped pair (idx1, idx1 + 1) may include the last character.
        idx1 = random.randint(1, len(word) - 2)
        chars = list(word)
        chars[idx1], chars[idx1 + 1] = chars[idx1 + 1], chars[idx1]
        return "".join(chars)

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Converts the given prompt by applying character swaps.

        Args:
            prompt (str): The prompt to be converted.
            input_type (PromptDataType): Type of the input; must be accepted by ``input_supported``.

        Returns:
            ConverterResult: The result containing the perturbed prompt as ``"text"``.

        Raises:
            ValueError: If ``input_type`` is not supported.
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        # Tokenize the prompt into words and punctuation runs. Keeping the match objects lets the
        # output be rebuilt around the original separators instead of re-joining with spaces.
        matches = list(re.finditer(r"\w+|\S+", prompt))
        words = [match.group() for match in matches]
        word_list_len = len(words)

        if word_list_len == 0:
            # Empty or whitespace-only prompt: nothing to perturb.
            return ConverterResult(output_text="", output_type="text")

        # At least one token is always selected; with 0 < ratio <= 1 this never exceeds word_list_len.
        num_perturb_words = max(1, math.ceil(word_list_len * self.word_swap_ratio))

        # Copy the original word list for perturbation
        perturbed_word_list = words.copy()

        # Get random indices of words to undergo swapping, then perturb those words
        random_words_idx = self._get_n_random(0, word_list_len, num_perturb_words)
        for idx in random_words_idx:
            perturbed_word_list[idx] = self._perturb_word(perturbed_word_list[idx])

        # Splice each (possibly perturbed) token back into its original span so the whitespace
        # between tokens (spaces, newlines, tabs, or none at all) is preserved verbatim.
        pieces = []
        cursor = 0
        for match, token in zip(matches, perturbed_word_list):
            pieces.append(prompt[cursor : match.start()])
            pieces.append(token)
            cursor = match.end()
        pieces.append(prompt[cursor:])

        # Leading and trailing whitespace is stripped from the result.
        output_text = "".join(pieces).strip()

        return ConverterResult(output_text=output_text, output_type="text")

    def _get_n_random(self, low: int, high: int, n: int) -> list:
        """
        Utility function to generate random indices.
        Words at these indices will be subjected to perturbation.

        Args:
            low (int): Inclusive lower bound of the index range.
            high (int): Exclusive upper bound of the index range.
            n (int): Number of distinct indices to draw.

        Returns:
            list: ``n`` distinct indices from ``range(low, high)``, or an empty list (logged at
            debug level) when ``random.sample`` rejects ``n`` for the population size.
        """
        try:
            return random.sample(range(low, high), n)
        except ValueError:
            logger.debug("[CharSwapConverter] Sample size of %s exceeds population size of %s", n, high - low)
            return []