Source code for pyrit.prompt_converter.charswap_attack_converter

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random
import re
import string
from typing import List, Optional, Union

from pyrit.prompt_converter.word_level_converter import WordLevelConverter


class CharSwapConverter(WordLevelConverter):
    """
    Applies character swapping to words in the prompt to test adversarial textual robustness.

    Which words get converted is decided by the ``WordLevelConverter`` base class (via the
    ``indices`` / ``keywords`` / ``proportion`` / ``regex`` selection arguments). This class only
    defines how one word is perturbed: two adjacent characters are swapped at a random position.
    """

    # ASCII punctuation characters as a set, so a word can be tested character by character.
    # Note: ``word in string.punctuation`` on a str is a *substring* test, which is not what we want.
    _PUNCTUATION = frozenset(string.punctuation)

    def __init__(
        self,
        *,
        max_iterations: int = 10,
        indices: Optional[List[int]] = None,
        keywords: Optional[List[str]] = None,
        proportion: Optional[float] = 0.2,
        regex: Optional[Union[str, re.Pattern]] = None,
    ):
        """
        Initializes the converter with the specified parameters.

        This class allows for selection of words to convert based on various criteria.
        Only one selection parameter may be provided at a time (indices, keywords, proportion, or regex).
        By default, proportion is set to 0.2, meaning 20% of randomly selected words will be perturbed.

        Args:
            max_iterations (int): Number of times to generate perturbed prompts. Must be greater than 0.
                Stored as ``self.max_iterations``. NOTE(review): no method of this class reads it;
                confirm whether a caller or the base class consumes it.
            indices (Optional[List[int]]): Specific indices of words to convert.
            keywords (Optional[List[str]]): Keywords to select words for conversion.
            proportion (Optional[float]): Proportion of randomly selected words to convert [0.0-1.0].
            regex (Optional[Union[str, re.Pattern]]): Regex pattern to match words for conversion.

        Raises:
            ValueError: If ``max_iterations`` is less than or equal to 0.
        """
        super().__init__(indices=indices, keywords=keywords, proportion=proportion, regex=regex)

        # Ensure max_iterations is positive
        if max_iterations <= 0:
            raise ValueError("max_iterations must be greater than 0")

        self.max_iterations = max_iterations

    async def convert_word_async(self, word: str) -> str:
        """
        Converts a single word by delegating to ``_perturb_word``.

        Presumably invoked by ``WordLevelConverter`` for each selected word (not visible here).

        Args:
            word (str): The word to convert.

        Returns:
            str: The perturbed word, or ``word`` unchanged if it is not eligible for perturbation.
        """
        return self._perturb_word(word)

    def _perturb_word(self, word: str) -> str:
        """
        Perturbs a word by swapping two adjacent characters.

        Words of 3 or fewer characters, and words made up entirely of ASCII punctuation, are
        returned unchanged. Otherwise, one random pair of adjacent characters is swapped.

        Args:
            word (str): The word to perturb.

        Returns:
            str: The perturbed word with swapped characters, or the original word if it is not eligible.
        """
        # Skip short words and pure-punctuation tokens. The per-character check replaces the old
        # ``word not in string.punctuation`` substring test, which let tokens such as "?!?!" through.
        if len(word) <= 3 or all(ch in self._PUNCTUATION for ch in word):
            return word

        # idx is drawn from [1, len(word) - 2] (randint is inclusive), so the first character never
        # moves. Its swap partner idx + 1 may be the last character.
        # NOTE(review): the len > 3 threshold hints the last character was meant to stay fixed too
        # (that would be randint(1, len(word) - 3)); behavior is kept as-is pending confirmation.
        idx = random.randint(1, len(word) - 2)
        chars = list(word)
        chars[idx], chars[idx + 1] = chars[idx + 1], chars[idx]
        return "".join(chars)