Source code for pyrit.prompt_target.http_target.http_target_callback_functions
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import re
from typing import Callable
import requests
[docs]
def get_http_target_json_response_callback_function(key: str) -> Callable:
"""
Determines proper parsing response function for an HTTP Request
Parameters:
key (str): this is the path pattern to follow for parsing the output response
(ie for AOAI this would be choices[0].message.content)
(for BIC this needs to be a regex pattern for the desired output)
response_type (ResponseType): this is the type of response (ie HTML or JSON)
Returns: proper output parsing response
"""
def parse_json_http_response(response: requests.Response):
"""
Parses JSON outputs
Parameters:
response (response): the HTTP Response to parse
Returns: parsed output from response given a "key" path to follow
"""
json_response = json.loads(response.content)
data_key = _fetch_key(data=json_response, key=key)
return data_key
return parse_json_http_response
[docs]
def get_http_target_regex_matching_callback_function(key: str, url: str = None) -> Callable:
def parse_using_regex(response: requests.Response):
"""
Parses text outputs using regex
Parameters:
url (optional str): the original URL if this is needed to get a full URL response back (ie BIC)
key (str): this is the regex pattern to follow for parsing the output response
response (response): the HTTP Response to parse
Returns: parsed output from response given a regex pattern to follow
"""
re_pattern = re.compile(key)
match = re.search(re_pattern, str(response.content))
if match:
if url:
return url + match.group()
else:
return match.group()
else:
return str(response.content)
return parse_using_regex
def _fetch_key(data: dict, key: str):
"""
Fetches the answer from the HTTP JSON response based on the path.
Args:
data (dict): HTTP response data.
key (str): The key path to fetch the value.
Returns:
str: The fetched value.
"""
pattern = re.compile(r"([a-zA-Z_]+)|\[(-?\d+)\]")
keys = pattern.findall(key)
for key_part, index_part in keys:
if key_part:
data = data.get(key_part, None)
elif index_part and isinstance(data, list):
data = data[int(index_part)] if -len(data) <= int(index_part) < len(data) else None
if data is None:
return ""
return data