AML Connect
Helper code for connecting to AzureML and sharing one workspace across the code.
add_cli_args(parser)
Adds parser arguments for connecting to AzureML
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| parser | argparse.ArgumentParser | parser to add AzureML arguments to | required |
Returns:

| Type | Description |
|---|---|
| argparse.ArgumentParser | that same parser |
Source code in shrike/pipeline/aml_connect.py
def add_cli_args(parser):
    """Adds parser arguments for connecting to AzureML.

    The workspace is identified either by a config file (``--aml-config``)
    or by the subscription id / resource group / workspace name triple.
    Authentication is selected with ``--aml-auth``, optionally tuned with
    ``--aml-tenant`` and ``--aml-force``.

    Args:
        parser (argparse.ArgumentParser): parser (or argument group) to add
            AzureML arguments to
    Returns:
        argparse.ArgumentParser: that same parser
    """
    parser.add_argument(
        "--aml-subscription-id",
        dest="aml_subscription_id",
        type=str,
        required=False,
        help="AzureML subscription id (used when --aml-config is not given)",
    )
    parser.add_argument(
        "--aml-resource-group",
        dest="aml_resource_group",
        type=str,
        required=False,
        help="AzureML resource group (used when --aml-config is not given)",
    )
    parser.add_argument(
        "--aml-workspace",
        dest="aml_workspace_name",
        type=str,
        required=False,
        help="AzureML workspace name (used when --aml-config is not given)",
    )
    parser.add_argument(
        "--aml-config",
        dest="aml_config",
        type=str,
        required=False,
        help="path to aml config.json file",
    )
    parser.add_argument(
        "--aml-auth",
        dest="aml_auth",
        type=str,
        choices=["azurecli", "msi", "interactive"],
        default="interactive",
        help="authentication method (default: interactive)",
    )
    parser.add_argument(
        "--aml-tenant",
        dest="aml_tenant",
        type=str,
        default=None,
        help="tenant to use for auth (default: auto)",
    )
    parser.add_argument(
        "--aml-force",
        dest="aml_force",
        # takes a value (e.g. `--aml-force True`) rather than acting as a flag;
        # "true"/"1"/"yes" (case-insensitive) mean True, anything else False
        type=lambda x: (str(x).lower() in ("true", "1", "yes")),
        default=False,
        help="force tenant auth (default: False)",
    )
    return parser
azureml_connect(**kwargs)
Calls azureml_connect_cli with an argparse-like structure based on keyword arguments
Source code in shrike/pipeline/aml_connect.py
def azureml_connect(**kwargs):
    """Connect to AzureML from keyword arguments instead of a CLI parser.

    A namedtuple carrying the same ``aml_*`` attributes as the namespace
    produced by ``add_cli_args`` is built from ``kwargs`` and handed to
    ``azureml_connect_cli``.

    Keyword Args:
        aml_subscription_id, aml_resource_group, aml_workspace_name,
        aml_config, aml_auth, aml_tenant, aml_force: forwarded unchanged.
            Any of these that is not supplied is set to None. Note that
            this includes ``aml_auth``, which therefore does not get the
            CLI default of "interactive". Other keyword arguments are
            ignored.

    Returns:
        whatever ``azureml_connect_cli`` returns for those arguments.
    """
    field_names = (
        "aml_subscription_id",
        "aml_resource_group",
        "aml_workspace_name",
        "aml_config",
        "aml_auth",
        "aml_tenant",
        "aml_force",
    )
    values = {name: kwargs.get(name) for name in field_names}
    arguments_type = namedtuple("AzureMLArguments", values)
    return azureml_connect_cli(arguments_type(**values))
azureml_connect_cli(args)
Connects to an AzureML workspace.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| args | argparse.Namespace | arguments to connect to AzureML | required |
Returns:

| Type | Description |
|---|---|
| azureml.core.Workspace | AzureML workspace |
Source code in shrike/pipeline/aml_connect.py
def azureml_connect_cli(args):
    """Connects to an AzureML workspace from parsed arguments.

    Steps:
    1. Build an authentication object from ``args.aml_auth``.
    2. Load the workspace, either from the ``args.aml_config`` file when one
       is given, or by subscription id / resource group / workspace name.
    3. Log the workspace details, then pass the workspace through
       ``current_workspace()`` and return the result.

    Args:
        args (argparse.Namespace): arguments to connect to AzureML (any object
            exposing the ``aml_*`` attributes defined by ``add_cli_args``)
    Returns:
        azureml.core.Workspace: AzureML workspace
    """
    auth_mode = args.aml_auth
    # stays None for any mode not handled below (e.g. aml_auth=None from
    # azureml_connect); None is then passed straight to the Workspace call
    auth = None
    # the azureml auth classes are imported inside each branch, so only the
    # selected one is ever imported
    if auth_mode == "msi":
        from azureml.core.authentication import MsiAuthentication

        auth = MsiAuthentication()
    elif auth_mode == "azurecli":
        from azureml.core.authentication import AzureCliAuthentication

        auth = AzureCliAuthentication()
    elif auth_mode == "interactive":
        from azureml.core.authentication import InteractiveLoginAuthentication

        auth = InteractiveLoginAuthentication(
            tenant_id=args.aml_tenant, force=args.aml_force
        )

    config_path = args.aml_config
    if not config_path:
        aml_ws = Workspace.get(
            subscription_id=args.aml_subscription_id,
            name=args.aml_workspace_name,
            resource_group=args.aml_resource_group,
            auth=auth,
        )
    else:
        # the config path is split into its directory and file name parts
        aml_ws = Workspace.from_config(
            path=os.path.dirname(config_path),
            _file_name=os.path.basename(config_path),
            auth=auth,
        )

    log.info("Connected to workspace:")
    log.info(f"\tsubscription: {aml_ws.subscription_id}")
    log.info(f"\tname: {aml_ws.name}")
    log.info(f"\tAzure region: {aml_ws.location}")
    log.info(f"\tresource group: {aml_ws.resource_group}")
    return current_workspace(aml_ws)
current_workspace(workspace=None)
Sets/Gets the current AML workspace shared across the code.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| workspace | azureml.core.Workspace | any given workspace | None |
Returns:

| Type | Description |
|---|---|
| azureml.core.Workspace | current (last) workspace given to current_workspace() |
Source code in shrike/pipeline/aml_connect.py
def current_workspace(workspace=None):
    """Sets/Gets the current AML workspace shared across the code.

    Args:
        workspace (azureml.core.Workspace, optional): when not None, stored
            as the new current workspace.
    Returns:
        azureml.core.Workspace: current (last) workspace given to
        current_workspace()
    Raises:
        RuntimeError: if no workspace has been set yet. This is a subclass
            of Exception, so existing ``except Exception`` handlers still
            catch it.
    """
    # NOTE(review): assumes the module defines CURRENT_AML_WORKSPACE = None at
    # import time (not visible in this excerpt); otherwise the read below
    # raises NameError before any workspace is set.
    global CURRENT_AML_WORKSPACE
    # compare against None explicitly so a valid but falsy object is not
    # silently ignored
    if workspace is not None:
        CURRENT_AML_WORKSPACE = workspace
    if CURRENT_AML_WORKSPACE is None:
        raise RuntimeError(
            "You need to initialize current_workspace() with an AML workspace"
        )
    return CURRENT_AML_WORKSPACE
main()
Main function (for testing)
Source code in shrike/pipeline/aml_connect.py
def main():
    """Entry point for manual testing: parse AzureML args and connect.

    Only the AzureML connect arguments are registered. Any other
    command-line arguments are reported with a warning instead of
    causing an error.

    Returns:
        azureml.core.Workspace: the workspace returned by azureml_connect_cli()
    """
    cli = argparse.ArgumentParser(description=__doc__)
    add_cli_args(cli.add_argument_group("AzureML connect arguments"))
    parsed, unknown_args = cli.parse_known_args()
    if len(unknown_args) > 0:
        log.warning(f"You have provided unknown arguments {unknown_args}")
    return azureml_connect_cli(parsed)