Logging metrics in Azure ML Portal using shrike
Logging real-time metrics and sending them to the Azure Machine Learning (ML) workspace portal is supported by shrike >= 1.7.0. This page explains how to use the shrike library to log various types of metrics in AML. For information on how to use the Azure ML Python SDK for metrics logging, please check out this documentation.
The metrics-logging feature in shrike is supported in both eyes-on and eyes-off environments, and it also works in offline runs (e.g., on your local laptop or any non-AML virtual machine). For jobs in detonation chambers, please check out this internal note.
Before logging any metrics, call shrike.compliant_logging.enable_compliant_logging
with the argument use_aml_metrics=True
to connect with the workspace portal and set up data-category-aware logging, and pass
category=DataCategory.PUBLIC
on each logging call whose content is safe to surface.
Then continue to use the standard Python logging functionality as before.
There are various metric-logging functions to match the various metric types.
The API references for all metric-logging functions are available on this page.
Use the following methods in the logging APIs for different scenarios and metric types.
Logged Value | Example Code | Supported Types |
---|---|---|
Log image | log.metric_image(name='food', path='./breadpudding.jpg', plot=None, description='dessert', category=DataCategory.PUBLIC) | string, matplotlib.pyplot.plot |
Log an array of numeric values | log.metric_list(name="Fibonacci", value=[0, 1, 1, 2, 3, 5, 8], category=DataCategory.PUBLIC) | list, tuple |
Log a single value | log.metric_value(name='metric_value', value=1, category=DataCategory.PUBLIC) | scalar |
Log a row with 2 numerical columns | log.metric_row(name='Cosine Wave', angle=0, cos=1, category=DataCategory.PUBLIC) | scalar |
Log a table | log.metric_table(name="students", value={"name": ["James", "Robert", "Michael"], "number": [2, 3, 1]}, category=DataCategory.PUBLIC) | dict |
Log residuals | log.metric_residual(name="ml_residual", value=pandas.DataFrame([[1.0, 1.1], [2.0, 2.0], [3.0, 3.1]], columns=["pred", "targ"]), col_predict="pred", col_target="targ", category=DataCategory.PUBLIC) | dict, pandas.DataFrame, vaex.dataframe, spark dataframe |
Log confusion matrix | log.metric_confusion_matrix(name="animal_classification", value=vaex.from_arrays(x=["cat", "ant", "cat", "cat", "ant", "bird"], y=["ant", "ant", "cat", "cat", "ant", "cat"]), idx_true="x", idx_pred="y", category=DataCategory.PUBLIC) | dict, pandas.DataFrame, vaex.dataframe, spark dataframe |
Log accuracy table | log.metric_accuracy_table(name="accuracy_table", value=vaex.from_arrays(x=[0.1, 0.3, 0.7], y=["a", "b", "c"]), idx_true="x", idx_pred="y", category=DataCategory.PUBLIC) | dict, pandas.DataFrame, vaex.dataframe, spark dataframe |
Here is a full-fledged example:
```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import logging
import random

import matplotlib.pyplot as plt

from shrike.compliant_logging import enable_compliant_logging
from shrike.compliant_logging.constants import DataCategory


def run(args):
    n = args.list_length
    list1 = [random.randint(0, 100) for i in range(n)]
    list2 = [random.randint(0, 100) for i in range(n)]

    log = logging.getLogger(__name__)
    log.info(
        "Start metric logging in azure ml workspace portal",
        category=DataCategory.PUBLIC,
    )

    # log list
    log.metric_list(name="list1", value=list1, category=DataCategory.PUBLIC)

    # log table
    log.metric_table(
        name="Lists",
        value={"list1": list1, "list2": list2},
        category=DataCategory.PUBLIC,
    )

    # log scalar value
    log.metric(name="sum1", value=sum(list1), category=DataCategory.PUBLIC)
    log.metric(name="sum2", value=sum(list2), category=DataCategory.PUBLIC)

    # log image
    plt.plot(list1, list2)
    log.metric_image(name="Sample plot", plot=plt, category=DataCategory.PUBLIC)

    # log row
    for i in range(n):
        log.metric_row(
            name="pairwise-sum",
            description="",
            category=DataCategory.PUBLIC,
            pairwise_sum=list1[i] + list2[i],
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prefix", default="SystemLog:")
    parser.add_argument("--log_level", default="INFO")
    parser.add_argument(
        "--list_length",
        required=False,
        default=5,
        type=int,
        help="length of test list",
    )
    args = parser.parse_args()

    enable_compliant_logging(
        args.prefix,
        level=args.log_level,
        format="%(prefix)s%(levelname)s:%(name)s:%(message)s",
        use_aml_metrics=True,
    )

    run(args)
```