
Logging metrics in Azure ML Portal using shrike

Logging real-time metrics to the Azure Machine Learning (ML) workspace portal is supported in shrike >= 1.7.0. This page describes how to use the shrike library to log various types of metrics in AML. For information on using the Azure ML Python SDK for metrics logging, please check out this documentation.

This metrics logging feature in shrike is supported in both eyes-on and eyes-off environments, and it also works in offline runs (e.g., on your local laptop or on any non-AML virtual machine). For jobs in detonation chambers, please check out this internal note.

Before logging any metrics, call shrike.compliant_logging.enable_compliant_logging with the argument use_aml_metrics=True and category=DataCategory.PUBLIC to connect to the workspace portal and set up data-category-aware logging. Then continue to use the standard Python logging functionality as before. shrike provides a separate metric-logging function for each metric type. The API references for all metric-logging functions are available on this page.
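
A minimal setup sketch, assuming shrike may or may not be installed on the machine running the script. When shrike is importable it calls `enable_compliant_logging` with the same arguments as the full example on this page; otherwise it falls back to the standard `logging` module. The `setup_logging` helper, the `HAVE_SHRIKE` flag, and the fallback branch are hypothetical illustrations, not part of the shrike API.

```python
import logging

try:
    # shrike >= 1.7.0 provides AML metric logging through compliant logging
    from shrike.compliant_logging import enable_compliant_logging
    HAVE_SHRIKE = True
except ImportError:
    HAVE_SHRIKE = False


def setup_logging(prefix: str = "SystemLog:", level: str = "INFO") -> logging.Logger:
    """Hypothetical helper: enable shrike metric logging when it is available."""
    if HAVE_SHRIKE:
        # Same call as in the full example on this page
        enable_compliant_logging(
            prefix,
            level=level,
            format="%(prefix)s%(levelname)s:%(name)s:%(message)s",
            use_aml_metrics=True,
        )
    else:
        # Fallback (assumption): plain logging without the %(prefix)s field,
        # which is only supplied once compliant logging is enabled
        logging.basicConfig(level=level, format="%(levelname)s:%(name)s:%(message)s")
    return logging.getLogger(__name__)


log = setup_logging()
```

In the fallback branch the returned logger is a plain `logging.Logger`, so the `metric_*` methods and the `category=` keyword are not available; guard those calls with `HAVE_SHRIKE` if the script must also run without shrike.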

Use the following logger methods for the different scenarios and metric types.

| Logged value | Example code | Supported types |
|---|---|---|
| Image | `log.metric_image(name='food', path='./breadpudding.jpg', plot=None, description='dessert', category=DataCategory.PUBLIC)` | string (file path), matplotlib.pyplot plot |
| Array of numeric values | `log.metric_list(name="Fibonacci", value=[0, 1, 1, 2, 3, 5, 8], category=DataCategory.PUBLIC)` | list, tuple |
| Single value | `log.metric_value(name='metric_value', value=1, step=None, category=DataCategory.PUBLIC)` | scalar |
| Row with 2 numerical columns | `log.metric_row(name='Cosine Wave', angle=0, cos=1, category=DataCategory.PUBLIC)` | scalar |
| Table | `log.metric_table(name="students", value={"name": ["James", "Robert", "Michael"], "number": [2, 3, 1]}, category=DataCategory.PUBLIC)` | dict |
| Residuals | `log.metric_residual(name="ml_residual", value=pandas.DataFrame([[1.0, 1.1], [2.0, 2.0], [3.0, 3.1]], columns=["pred", "targ"]), col_predict="pred", col_target="targ", category=DataCategory.PUBLIC)` | dict, pandas.DataFrame, vaex.dataframe, Spark DataFrame |
| Confusion matrix | `log.metric_confusion_matrix(name="animal_classification", value=vaex.from_arrays(x=["cat", "ant", "cat", "cat", "ant", "bird"], y=["ant", "ant", "cat", "cat", "ant", "cat"]), idx_true="x", idx_pred="y", category=DataCategory.PUBLIC)` | dict, pandas.DataFrame, vaex.dataframe, Spark DataFrame |
| Accuracy table | `log.metric_confusion_matrix(name="accuracy_table", value=vaex.from_arrays(x=[0.1, 0.3, 0.7], y=["a", "b", "c"]), idx_true="x", idx_pred="y", category=DataCategory.PUBLIC)` | dict, pandas.DataFrame, vaex.dataframe, Spark DataFrame |
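
To see what the confusion-matrix example summarizes, here is a plain-Python sketch that tallies the same `x` (true) and `y` (predicted) arrays into a count matrix without vaex or shrike. The `confusion_counts` function is hypothetical and only illustrates the counting; shrike's own aggregation and the layout it sends to the portal may differ.

```python
from collections import Counter


def confusion_counts(y_true, y_pred):
    """Hypothetical helper: count (true, predicted) label pairs into a matrix.

    Returns the sorted label list and a nested list where matrix[i][j] is
    the number of samples with true label labels[i] predicted as labels[j].
    """
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    labels = sorted(set(y_true) | set(y_pred))
    pairs = Counter(zip(y_true, y_pred))  # missing pairs count as 0
    matrix = [[pairs[(t, p)] for p in labels] for t in labels]
    return labels, matrix


# Same data as the table row, with idx_true="x" and idx_pred="y"
x = ["cat", "ant", "cat", "cat", "ant", "bird"]
y = ["ant", "ant", "cat", "cat", "ant", "cat"]
labels, matrix = confusion_counts(x, y)
# labels -> ['ant', 'bird', 'cat']
# matrix -> [[2, 0, 0], [0, 0, 1], [1, 0, 2]]
```

Reading the rows: both `ant` samples are predicted correctly, two of the three `cat` samples are predicted correctly and one is predicted as `ant`, and the single `bird` sample is predicted as `cat`.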

Here is a full-fledged example:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import logging
import random
import matplotlib.pyplot as plt

from shrike.compliant_logging import enable_compliant_logging
from shrike.compliant_logging.constants import DataCategory


def run(args):
    """Generate two random integer lists and log them as AML metrics."""
    n = args.list_length
    list1 = [random.randint(0, 100) for _ in range(n)]
    list2 = [random.randint(0, 100) for _ in range(n)]

    log = logging.getLogger(__name__)

    log.info(
        "Start metric logging in azure ml workspace portal",
        category=DataCategory.PUBLIC,
    )

    # log list
    log.metric_list(name="list1", value=list1, category=DataCategory.PUBLIC)

    # log table
    log.metric_table(
        name="Lists",
        value={"list1": list1, "list2": list2},
        category=DataCategory.PUBLIC,
    )

    # log scalar value
    log.metric(name="sum1", value=sum(list1), category=DataCategory.PUBLIC)
    log.metric(name="sum2", value=sum(list2), category=DataCategory.PUBLIC)

    # log image
    plt.plot(list1, list2)
    log.metric_image(name="Sample plot", plot=plt, category=DataCategory.PUBLIC)

    # log row
    for i in range(n):
        log.metric_row(
            name="pairwise-sum",
            description="",
            category=DataCategory.PUBLIC,
            pairwise_sum=list1[i] + list2[i],
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prefix", default="SystemLog:")
    parser.add_argument("--log_level", default="INFO")
    parser.add_argument(
        "--list_length",
        required=False,
        default=5,
        type=int,
        help="length of test list",
    )
    args = parser.parse_args()

    enable_compliant_logging(
        args.prefix,
        level=args.log_level,
        format="%(prefix)s%(levelname)s:%(name)s:%(message)s",
        use_aml_metrics=True,
    )

    run(args)
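
The `metric_table` call in this example passes a dict that maps column names to lists; here both lists have length `n`. When such dicts are built by hand it is easy to end up with ragged columns. The sketch below is a hypothetical pre-check, not part of shrike, that one could run before calling `log.metric_table`; it does not describe how shrike itself handles columns of unequal length.

```python
def check_table_columns(table: dict) -> int:
    """Hypothetical pre-check for metric_table values.

    Verifies that every column in a {column_name: list} dict has the same
    number of entries and returns that common row count.
    """
    if not table:
        raise ValueError("table must contain at least one column")
    lengths = {name: len(values) for name, values in table.items()}
    distinct = set(lengths.values())
    if len(distinct) != 1:
        raise ValueError(f"columns have different lengths: {lengths}")
    return distinct.pop()


# Columns shaped like the 'Lists' table in the example (n = 5)
rows = check_table_columns({"list1": [1, 2, 3, 4, 5], "list2": [6, 7, 8, 9, 10]})
```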