![LearnAI Header](https://coursematerial.blob.core.windows.net/assets/LearnAI_header.png)

# Machine Learning Experimentation and Model Management with AML Services

In this lab you will learn how to integrate Azure Databricks with Azure Machine Learning. 

This powerful combination allows you to do the following:
- Keep track of the performance of your various machine learning solutions
- Register a candidate model for deployment
- Create a Docker image
- Deploy your solution as a web service

We now run a training experiment and use the Azure ML SDK to save it to our AML Workspace.

In [4]:
import os
import pprint
import numpy as np

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [5]:
# Report which release of the Azure ML core SDK is installed on this cluster;
# API details differ between SDK versions, so this helps when debugging.
import azureml.core

print("SDK version:",
      azureml.core.VERSION)

Let's load our Azure ML Workspace first:

In [7]:
# Connect to the Azure ML workspace described by the config.json stored on
# DBFS at /dbfs/tmp/aml_config/config.json, then show which workspace,
# region and resource group we are attached to.
from azureml.core import Workspace

config_path = '/dbfs/tmp/'

ws = Workspace.from_config(
    path=os.path.join(config_path, 'aml_config', 'config.json'))

print('\n'.join([
    'Workspace name: ' + ws.name,
    'Azure region: ' + ws.location,
    'Resource group: ' + ws.resource_group,
]))

## Read the data

In [9]:
# Load the preprocessed feature table, expose the first target column (y_0)
# under the name Spark ML expects by default ("label"), and cache it because
# it is filtered and scanned repeatedly below.
df = (spark.read
      .parquet("dbfs:/FileStore/tables/preprocessed")
      .withColumnRenamed("y_0", "label")
      .cache())
display(df)

machineID,datetime,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,pressure_ma_3,pressure_sd_3,rotate_ma_3,rotate_sd_3,vibration_ma_3,vibration_sd_3,volt_ma_3,volt_sd_3,label,y_1,y_2,y_3
32,2015-02-23T21:00:00.000+0000,15,1095.0,166.0,409.0,1387.0,375.0,1071.0,1387.0,1387.0,351.0,711.0,1387.0,1387.0,351.0,99.47241720086753,4.027627389831299,451.3898073666962,37.143235255898006,36.43797824633008,5.246946040863718,175.77508414243,27.49122867502091,0,0,0,0
32,2015-02-23T22:00:00.000+0000,15,1096.0,167.0,410.0,1388.0,376.0,1072.0,1388.0,1388.0,352.0,712.0,1388.0,1388.0,352.0,100.09894109745454,3.794141420379903,435.17321536069,39.335699599244414,39.2308973568302,4.176148610708063,180.83567404661,23.740298384283044,0,0,0,0
32,2015-02-23T23:00:00.000+0000,15,1097.0,168.0,411.0,1389.0,377.0,1073.0,1389.0,1389.0,353.0,713.0,1389.0,1389.0,353.0,99.5443989022908,3.475883016729516,421.66272855994526,15.279105023168729,41.71121231384191,0.9869915716793244,169.48155636960098,21.84432149813136,0,0,0,0
32,2015-02-24T00:00:00.000+0000,15,1098.0,169.0,412.0,1390.0,378.0,1074.0,1390.0,1390.0,354.0,714.0,1390.0,1390.0,354.0,97.68908923167596,2.826179881000501,416.56100286652577,14.313070539126413,43.106649846394575,2.4962971618940304,168.40722316791698,19.97804980820004,0,0,0,0
32,2015-02-24T01:00:00.000+0000,15,1099.0,170.0,413.0,1391.0,379.0,1075.0,1391.0,1391.0,355.0,715.0,1391.0,1391.0,355.0,96.91916094721388,4.036157929664655,428.15124045533855,23.48933395357947,42.8987242528714,2.771261929609087,172.9402295425665,16.124035462907294,0,0,0,0
32,2015-02-24T02:00:00.000+0000,15,1100.0,171.0,414.0,1392.0,380.0,1076.0,1392.0,1392.0,356.0,716.0,1392.0,1392.0,356.0,98.60820305322396,6.429119080028224,444.72279625965575,26.72521659029804,43.55483208223855,2.8184999168524905,166.95757356868276,19.57722439213211,0,0,0,0
32,2015-02-24T03:00:00.000+0000,15,1101.0,172.0,415.0,1393.0,381.0,1077.0,1393.0,1393.0,357.0,717.0,1393.0,1393.0,357.0,97.75829917013183,6.35926449953383,455.000717946228,30.80270066186616,46.0032332920517,5.284176271790364,171.5146198226785,18.02024867702125,0,0,0,0
32,2015-02-24T04:00:00.000+0000,15,1102.0,173.0,416.0,1394.0,382.0,1078.0,1394.0,1394.0,358.0,718.0,1394.0,1394.0,358.0,97.46746997114926,6.507590246654807,464.6340331415173,13.781545445572954,44.00402649786548,6.376598773584249,163.08366794565802,10.130385844311215,0,0,0,0
32,2015-02-24T05:00:00.000+0000,15,1103.0,174.0,417.0,1395.0,383.0,1079.0,1395.0,1395.0,359.0,719.0,1395.0,1395.0,359.0,103.47412855441392,9.726682380322783,481.54989358335945,31.29419249089553,41.575592514816776,9.504942586795329,164.627163319354,11.547022705888777,0,0,0,0
32,2015-02-24T06:00:00.000+0000,15,1104.0,175.0,418.0,1396.0,384.0,1080.0,1396.0,1396.0,360.0,720.0,0.0,1396.0,360.0,99.97419460128005,10.652800910349333,479.901478683917,32.041360774820994,39.010816091632826,9.76504784353832,168.14560647024325,6.929237687861469,0,0,0,0


In [10]:
# Identifier columns: they locate a row but are not model inputs.
keys = ['machineID', 'datetime']

# Feature columns assembled into the "features" vector by the VectorAssembler.
X_keep = ['diff_maint_1', 'diff_error_1', 'volt_sd_3', 'diff_fail_3', 'pressure_ma_3', 'pressure_sd_3', 'diff_fail_1', 'diff_fail_0', 'age', 'vibration_ma_3', 'rotate_ma_3', 'diff_error_2', 'diff_fail_2', 'diff_error_3', 'diff_maint_2', 'volt_ma_3', 'diff_maint_0', 'vibration_sd_3', 'diff_maint_3', 'rotate_sd_3', 'diff_error_0', 'diff_error_4']

# Target columns. y_0 was renamed to "label" when the data was loaded, so it
# must be referred to by its current name here.
Y_keep = ['label', 'y_1', 'y_2', 'y_3']

In [11]:
# Stdlib datetime: the old `from pandas import datetime` alias was deprecated
# in pandas 1.0 and removed in pandas 2.0 (it was the same class).
from datetime import datetime
from pyspark.sql.functions import col, hour

# Time-based split: train on rows before 2015-10-01, test on rows strictly
# after 2015-10-15. NOTE(review): the two-week gap presumably keeps the
# look-ahead targets (label, y_1..y_3) of training rows from overlapping the
# test window -- TODO confirm against the preprocessing notebook.
TRAIN_END = datetime(2015, 10, 1)
TEST_START = datetime(2015, 10, 15)

# Optional down-sampling of the training set: keep only rows whose hour is a
# multiple of SAMPLE_EVERY_N_HOURS. The default of 1 keeps every row.
SAMPLE_EVERY_N_HOURS = 1

train_condition = col('datetime') < TRAIN_END
if SAMPLE_EVERY_N_HOURS > 1:
    train_condition = train_condition & (hour(col('datetime')) % SAMPLE_EVERY_N_HOURS == 0)

df_train = df.filter(train_condition)
df_test = df.filter(col('datetime') > TEST_START)

## Define Model

It is time to run the experiment. To do so, we create an `Experiment` in our workspace and call its `start_logging` method to obtain the root run. Each regularization rate is then trained inside its own child run (created with `child_run`), and we log the metrics we care about to that child run. Examine the code below to see it all in action.

In [14]:
from azureml.core.run import Run
from azureml.core.experiment import Experiment
import numpy as np
import os
import shutil
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml import Pipeline

# Feature stages shared by every candidate: assemble the X_keep columns into a
# single vector, then standardize that vector.
vassembler = VectorAssembler(inputCols = X_keep, outputCol = "features")
stndscaler = StandardScaler(inputCol = "features", outputCol = "norm_features")

# Spark ML saves to this relative name, which Databricks resolves under the
# DBFS root; model_dbfs is the same location seen through the /dbfs mount.
model_name = "PdM_logistic_regression.mml"
model_dbfs = os.path.join("/dbfs", model_name)
run_history_name = 'spark-ml-notebook'

# The evaluator is identical for every candidate, so build it once.
# Note: only areaUnderROC and areaUnderPR are supported out of the box by Spark ML.
bce = (BinaryClassificationEvaluator()
       .setLabelCol("label")
       .setRawPredictionCol('rawPrediction'))

# start a training run by defining an experiment
myexperiment = Experiment(ws, "AI_Airlft")
root_run = myexperiment.start_logging()

# Regularization rates to try in the Logistic Regression model.
regs = [0.0001, 0.001, 0.01, 0.1]

try:
    for reg in regs:
        print("Regularization rate: {}".format(reg))
        # each candidate gets its own child run in the run history
        with root_run.child_run("reg-" + str(reg)) as run:
            lr = (LogisticRegression(regParam=reg)
                  .setLabelCol("label")
                  .setFeaturesCol("norm_features"))

            pipe = Pipeline(stages=[vassembler, stndscaler, lr])
            model_p = pipe.fit(df_train)

            # NOTE(review): candidates are scored on df_test, and the best one
            # is later chosen by this score, so test metrics are optimistic.
            pred = model_p.transform(df_test)

            au_roc = bce.setMetricName('areaUnderROC').evaluate(pred)
            au_prc = bce.setMetricName('areaUnderPR').evaluate(pred)

            print("Area under ROC: {}".format(au_roc))
            print("Area Under PR: {}".format(au_prc))

            # Log reg, au_roc, au_prc and the model's input feature names.
            # X_keep is logged rather than df_train.columns, which also holds
            # key and target columns that are not model inputs.
            run.log("reg", reg)
            run.log("au_roc", au_roc)
            run.log("au_prc", au_prc)
            run.log_list("columns", X_keep)

            # Save the fitted pipeline (a folder), zip it, and upload the zip
            # under "outputs/<model_name>" -- upload_file cannot take folders.
            model_p.write().overwrite().save(model_name)
            mdl = os.path.splitext(model_name)[0]
            model_zip = mdl + ".zip"
            shutil.make_archive(mdl, 'zip', model_dbfs)
            run.upload_file("outputs/" + model_name, model_zip)

            # The model now lives in run history; remove the local copies.
            shutil.rmtree(model_dbfs)
            os.remove(model_zip)
except Exception:
    # Mark the parent run failed instead of leaving it stuck as "Running".
    root_run.fail()
    raise

# Declare run completed
root_run.complete()
root_run_id = root_run.id
print ("run id:", root_run.id)

In [15]:
# load all run metrics from run history into a dictionary object
child_runs = {}

for r in root_run.get_children():
    child_runs[r.id] = r

In [16]:
# Display the mapping of child-run id -> Run object built in the previous cell.
child_runs

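We can now select the best model based on the metric we choose, here the area under the ROC curve (`au_roc`). Note that the same test set is used both to pick the model and to report its score, so that score is somewhat optimistic; a separate validation split would give an unbiased estimate.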

In [18]:
# Logged metrics of the root run and its descendants, keyed by run id.
metrics = root_run.get_metrics(recursive = True)

# Only child runs that actually logged 'au_roc' are candidates. The root run
# logs no metrics itself, and it is not in child_runs, so it must be excluded.
candidate_ids = [run_id for run_id in metrics
                 if run_id in child_runs and 'au_roc' in metrics[run_id]]
if not candidate_ids:
    raise ValueError("No child run logged an 'au_roc' metric; cannot select a best run.")

best_run_id = max(candidate_ids, key = lambda k: metrics[k]['au_roc'])
best_run = child_runs[best_run_id]
print('Best run is:', best_run_id)
print('Metrics:', metrics[best_run_id]['au_roc'], metrics[best_run_id]['reg'])

We download the best model from the run history and save it to local disk for future use.

In [20]:
# Fetch the zipped pipeline that the best child run uploaded under
# "outputs/<model_name>" and write it to best_model.zip in the working directory.
best_model_file_name = "best_model.zip"
best_run.download_file(
    output_file_path=best_model_file_name,
    name='outputs/' + model_name,
)

## Model Evaluation

We can now load the best model we selected earlier and evaluate its performance on the test set.

In [23]:
## Unzip the model into DBFS and load it. PipelineModel.load() resolves the
## relative model_name under the DBFS root, so the files must end up at model_dbfs.
if os.path.isdir(model_dbfs):
    shutil.rmtree(model_dbfs)
elif os.path.exists(model_dbfs):
    # A stray regular file at this path: rmtree() would raise NotADirectoryError.
    os.remove(model_dbfs)
shutil.unpack_archive(best_model_file_name, model_dbfs)

model_p_best = PipelineModel.load(model_name)

In [24]:
# Score the test set with the selected pipeline, then count rows for every
# (predicted class, true label) pair -- a confusion-matrix style summary.
df_pred = model_p_best.transform(df_test)
display(df_pred.groupBy(["prediction", "label"]).count())

prediction,label,count
0.0,0,185712
0.0,1,2088


In [25]:
import pyspark.sql.functions as F

# Show 5 rows with the highest predicted class and 5 with the lowest.
# (The original overwrote the first selection and unioned the whole df_pred.)
# If every row has the same prediction, both halves are arbitrary rows.
df_select = (df_pred.orderBy(F.desc('prediction')).limit(5)
             .union(df_pred.orderBy(F.asc('prediction')).limit(5)))

display(df_select)

machineID,datetime,age,diff_error_0,diff_error_1,diff_error_2,diff_error_3,diff_error_4,diff_fail_0,diff_fail_1,diff_fail_2,diff_fail_3,diff_maint_0,diff_maint_1,diff_maint_2,diff_maint_3,pressure_ma_3,pressure_sd_3,rotate_ma_3,rotate_sd_3,vibration_ma_3,vibration_sd_3,volt_ma_3,volt_sd_3,label,y_1,y_2,y_3,features,norm_features,rawPrediction,probability,prediction
39,2015-10-15T01:00:00.000+0000,0,1759.0,278.0,67.0,231.0,153.0,1843.0,43.0,6983.0,6983.0,1123.0,43.0,1483.0,763.0,94.21543486894998,6.441336953515932,427.1108015695812,52.30227887677106,39.27557152618075,4.513392438930502,166.33543104143877,14.064982347527296,0,0,0,0,"List(1, 22, List(), List(43.0, 278.0, 14.064982347527296, 6983.0, 94.21543486894998, 6.441336953515932, 43.0, 1843.0, 0.0, 39.27557152618075, 427.1108015695812, 67.0, 6983.0, 231.0, 1483.0, 166.33543104143877, 1123.0, 4.513392438930502, 763.0, 52.30227887677106, 1759.0, 153.0))","List(1, 22, List(), List(0.04201172203737228, 0.2908104650747991, 2.406151349554706, 3.7410799147356126, 13.889709455086422, 1.6446668758830796, 0.02769192535651848, 1.1135081655890402, 0.0, 12.380794389583158, 14.4005506282904, 0.06739504841282014, 3.6077195113948237, 0.23213294447012395, 1.4278845707177827, 19.6892871754614, 1.1222425658147295, 2.3106445950514747, 0.7348350936167161, 2.684018700538735, 2.3049935361982463, 0.10720426303996886))","List(1, 2, List(), List(4.343159113187271, -4.343159113187271))","List(1, 2, List(), List(0.9871713048257431, 0.012828695174256999))",0.0
39,2015-10-15T02:00:00.000+0000,0,1760.0,279.0,68.0,232.0,154.0,1844.0,44.0,6984.0,6984.0,1124.0,44.0,1484.0,764.0,100.29140877617507,11.284847828477886,442.628398602288,57.79405244919041,40.4627650363102,5.15593570715991,159.08171336934024,4.742598565323689,0,0,0,0,"List(1, 22, List(), List(44.0, 279.0, 4.742598565323689, 6984.0, 100.29140877617509, 11.284847828477886, 44.0, 1844.0, 0.0, 40.4627650363102, 442.628398602288, 68.0, 6984.0, 232.0, 1484.0, 159.08171336934024, 1124.0, 5.15593570715991, 764.0, 57.79405244919041, 1760.0, 154.0))","List(1, 22, List(), List(0.042988738828939076, 0.29185654588442067, 0.8113348212168924, 3.7416156558088955, 14.785459841903942, 2.881360741848478, 0.02833592362062356, 1.1141123479903363, 0.0, 12.755032069097572, 14.92374493964405, 0.06840094465778761, 3.6082361546013817, 0.23313784899163964, 1.4288474058969585, 18.830657541104745, 1.1232418913408335, 2.6395965197754863, 0.7357981802400669, 2.9658424237884766, 2.306303936161975, 0.10790494449774643))","List(1, 2, List(), List(4.540672972913085, -4.540672972913085))","List(1, 2, List(), List(0.9894463416492582, 0.010553658350741709))",0.0
39,2015-10-15T03:00:00.000+0000,0,1761.0,280.0,69.0,233.0,155.0,1845.0,45.0,6985.0,6985.0,1125.0,45.0,1485.0,765.0,100.4415699380936,11.343353478846131,442.9228764783277,57.404479717880825,38.9660110968309,6.077115734667256,165.1331519461085,13.650043784109032,0,0,0,0,"List(1, 22, List(), List(45.0, 280.0, 13.650043784109032, 6985.0, 100.4415699380936, 11.343353478846133, 45.0, 1845.0, 0.0, 38.9660110968309, 442.92287647832774, 69.0, 6985.0, 233.0, 1485.0, 165.1331519461085, 1125.0, 6.077115734667256, 765.0, 57.404479717880825, 1761.0, 155.0))","List(1, 22, List(), List(0.04396575562050587, 0.2929026266940422, 2.335166192255403, 3.742151396882179, 14.80759734955741, 2.896298992386688, 0.02897992188472864, 1.1147165303916327, 0.0, 12.283212002414666, 14.933673612829645, 0.06940684090275506, 3.6087527978079397, 0.23414275351315533, 1.4298102410761344, 19.546972226536816, 1.1242412168669373, 3.111197356713476, 0.7367612668634178, 2.945850551187314, 2.307614336125703, 0.108605625955524))","List(1, 2, List(), List(4.386323525789271, -4.386323525789271))","List(1, 2, List(), List(0.9877066044769235, 0.01229339552307655))",0.0
39,2015-10-15T04:00:00.000+0000,0,1762.0,281.0,70.0,234.0,156.0,1846.0,46.0,6986.0,6986.0,1126.0,46.0,1486.0,766.0,102.36820210530072,10.50187836582813,470.885704935064,58.13539678531709,38.02902267033148,5.222749372837035,158.03387250251424,24.525674106431534,0,0,0,0,"List(1, 22, List(), List(46.0, 281.0, 24.525674106431538, 6986.0, 102.36820210530072, 10.50187836582813, 46.0, 1846.0, 0.0, 38.02902267033148, 470.885704935064, 70.0, 6986.0, 234.0, 1486.0, 158.03387250251424, 1126.0, 5.222749372837035, 766.0, 58.13539678531709, 1762.0, 156.0))","List(1, 22, List(), List(0.044942772412072673, 0.2939487075036638, 4.1957026601106175, 3.742687137955462, 15.091631075735643, 2.681445111080996, 0.02962392014883372, 1.115320712792929, 0.0, 11.987846190966819, 15.876473760757603, 0.07041273714772253, 3.6092694410144976, 0.23514765803467103, 1.4307730762553104, 18.70662359588966, 1.1252405423930414, 2.673801973336491, 0.7377243534867687, 2.983359338943275, 2.3089247360894314, 0.10930630741330158))","List(1, 2, List(), List(4.529832696140538, -4.529832696140538))","List(1, 2, List(), List(0.9893325417811886, 0.010667458218811363))",0.0
39,2015-10-15T05:00:00.000+0000,0,1763.0,282.0,71.0,235.0,157.0,1847.0,47.0,6987.0,6987.0,1127.0,47.0,1487.0,767.0,103.0400899854904,9.399464113043717,469.3866174883465,57.10461441576485,36.4770382878436,7.517098786147426,157.30145389586323,24.28877078859417,0,0,0,0,"List(1, 22, List(), List(47.0, 282.0, 24.288770788594174, 6987.0, 103.0400899854904, 9.399464113043717, 47.0, 1847.0, 0.0, 36.4770382878436, 469.3866174883465, 71.0, 6987.0, 235.0, 1487.0, 157.30145389586323, 1127.0, 7.517098786147426, 767.0, 57.10461441576485, 1763.0, 157.0))","List(1, 22, List(), List(0.04591978920363947, 0.2949947883132854, 4.155174686179063, 3.7432228790287447, 15.190684139123924, 2.3999656265981577, 0.030267918412938803, 1.1159248951942253, 0.0, 11.49861589364019, 15.825930237640517, 0.07141863339269, 3.6097860842210556, 0.2361525625561867, 1.431735911434486, 18.61992649119757, 1.1262398679191454, 3.848400934706978, 0.7386874401101196, 2.930462233588032, 2.31023513605316, 0.11000698887107915))","List(1, 2, List(), List(4.5437245989218855, -4.5437245989218855))","List(1, 2, List(), List(0.9894781600295124, 0.010521839970487525))",0.0
39,2015-10-15T06:00:00.000+0000,0,1764.0,283.0,72.0,236.0,158.0,1848.0,48.0,6988.0,6988.0,1128.0,48.0,1488.0,768.0,95.98438742852186,8.247213754633696,446.7774364487535,68.07805371841293,33.88723383411242,5.473339683893695,160.56317383021548,25.14330334174829,0,0,0,0,"List(1, 22, List(), List(48.0, 283.0, 25.14330334174829, 6988.0, 95.98438742852186, 8.247213754633696, 48.0, 1848.0, 0.0, 33.88723383411242, 446.7774364487535, 72.0, 6988.0, 236.0, 1488.0, 160.56317383021548, 1128.0, 5.473339683893695, 768.0, 68.07805371841293, 1764.0, 158.0))","List(1, 22, List(), List(0.046896805995206264, 0.296040869122907, 4.3013629006542615, 3.743758620102028, 14.150497266833625, 2.1057614868555588, 0.030911916677043884, 1.1165290775955214, 0.0, 10.682234738511768, 15.06363470442441, 0.07242452963765747, 3.6103027274276136, 0.23715746707770238, 1.432698746613662, 19.00601946051443, 1.1272391934452493, 2.8020924235133746, 0.7396505267334704, 3.4935909715714755, 2.3115455360168884, 0.11070767032885673))","List(1, 2, List(), List(4.440803532373727, -4.440803532373727))","List(1, 2, List(), List(0.9883508385130677, 0.011649161486932331))",0.0
39,2015-10-15T07:00:00.000+0000,0,1765.0,284.0,73.0,237.0,159.0,1849.0,49.0,6989.0,6989.0,1129.0,49.0,1489.0,769.0,95.37408813946335,7.516999772624291,472.58502381457726,55.953713871048606,33.14937094433025,5.626791017653033,158.9978155466815,23.294652524522714,0,0,0,0,"List(1, 22, List(), List(49.0, 284.0, 23.294652524522714, 6989.0, 95.37408813946337, 7.516999772624291, 49.0, 1849.0, 0.0, 33.14937094433025, 472.58502381457726, 73.0, 6989.0, 237.0, 1489.0, 158.9978155466815, 1129.0, 5.626791017653033, 769.0, 55.953713871048606, 1765.0, 159.0))","List(1, 22, List(), List(0.04787382278677306, 0.29708694993252854, 3.9851070000910633, 3.744294361175311, 14.060523900819243, 1.9193159155115505, 0.03155591494114896, 1.1171332599968178, 0.0, 10.449639046810447, 15.933768325699797, 0.07343042588262493, 3.610819370634172, 0.23816237159921808, 1.4336615817928378, 18.820726474023118, 1.1282385189713533, 2.880652287241616, 0.7406136133568213, 2.871400971807132, 2.3128559359806165, 0.1114083517866343))","List(1, 2, List(), List(4.475183118610085, -4.475183118610085))","List(1, 2, List(), List(0.9887400928068701, 0.011259907193129807))",0.0
39,2015-10-15T08:00:00.000+0000,0,1766.0,285.0,74.0,238.0,160.0,1850.0,50.0,6990.0,6990.0,1130.0,50.0,1490.0,770.0,96.26219706125636,8.636840646400076,472.7027662545422,56.07018134255616,32.99660722936027,5.363287900106483,169.107364792492,6.8666067034999125,0,0,0,0,"List(1, 22, List(), List(50.0, 285.0, 6.8666067034999125, 6990.0, 96.26219706125636, 8.636840646400076, 50.0, 1850.0, 0.0, 32.99660722936027, 472.7027662545422, 74.0, 6990.0, 238.0, 1490.0, 169.107364792492, 1130.0, 5.363287900106483, 770.0, 56.070181342556154, 1766.0, 160.0))","List(1, 22, List(), List(0.048850839578339855, 0.29813303074215014, 1.174697171901704, 3.744830102248594, 14.191453348901005, 2.205244939974972, 0.03219991320525405, 1.1177374423981141, 0.0, 10.401483512167967, 15.937738152643023, 0.0744363221275924, 3.61133601384073, 0.23916727612073377, 1.4346244169720137, 20.017403676642946, 1.1292378444974571, 2.745751087628133, 0.7415766999801722, 2.8773777834919008, 2.314166335944345, 0.11210903324441188))","List(1, 2, List(), List(4.270838577709067, -4.270838577709067))","List(1, 2, List(), List(0.9862224105399156, 0.013777589460084323))",0.0
39,2015-10-15T09:00:00.000+0000,0,1767.0,286.0,75.0,239.0,161.0,1851.0,51.0,6991.0,6991.0,1131.0,51.0,1491.0,771.0,98.9952616527612,8.483691997905721,442.9278262463732,69.08354465208387,37.52180701259352,6.2733621222245,172.82673604134072,5.7636521611340905,0,0,0,0,"List(1, 22, List(), List(51.0, 286.0, 5.7636521611340905, 6991.0, 98.9952616527612, 8.483691997905721, 51.0, 1851.0, 0.0, 37.52180701259352, 442.92782624637323, 75.0, 6991.0, 239.0, 1491.0, 172.82673604134072, 1131.0, 6.2733621222245, 771.0, 69.08354465208387, 1767.0, 161.0))","List(1, 22, List(), List(0.04982785636990666, 0.29917911155177174, 0.986010439488024, 3.7453658433218773, 14.594375366411082, 2.1661414881477175, 0.032843911469359124, 1.1183416247994102, 0.0, 11.827957167698377, 14.933840500169127, 0.07544221837255986, 3.611852657047288, 0.24017218064224946, 1.4355872521511894, 20.457669278338432, 1.1302371700235612, 3.21166627468219, 0.742539786603523, 3.5451901853562795, 2.3154767359080735, 0.11280971470218944))","List(1, 2, List(), List(4.212181151192999, -4.212181151192999))","List(1, 2, List(), List(0.985402230221059, 0.014597769778941054))",0.0
39,2015-10-15T10:00:00.000+0000,0,1768.0,287.0,76.0,240.0,162.0,1852.0,52.0,6992.0,6992.0,1132.0,52.0,1492.0,772.0,101.01374646724754,4.573114412731661,470.33976732648273,62.5704339150958,41.89108118541365,8.538073710017958,173.772954179621,5.518722666805894,0,0,0,0,"List(1, 22, List(), List(52.0, 287.0, 5.518722666805894, 6992.0, 101.01374646724754, 4.573114412731661, 52.0, 1852.0, 0.0, 41.89108118541365, 470.33976732648273, 76.0, 6992.0, 240.0, 1492.0, 173.772954179621, 1132.0, 8.538073710017958, 772.0, 62.5704339150958, 1768.0, 162.0))","List(1, 22, List(), List(0.05080487316147345, 0.3002251923613933, 0.9441093962615353, 3.74590158439516, 14.891950468109828, 1.1676535241861357, 0.03348790973346421, 1.1189458072007066, 0.0, 13.205278567829824, 15.858066822456417, 0.07644811461752732, 3.612369300253846, 0.24117708516376515, 1.4365500873303654, 20.569674042077487, 1.131236495549665, 4.371092063706953, 0.7435028732268739, 3.2109540604267486, 2.316787135871802, 0.11351039615996701))","List(1, 2, List(), List(4.213761333802227, -4.213761333802227))","List(1, 2, List(), List(0.9854249432080338, 0.014575056791966296))",0.0


In [26]:
# Score the selected model on the test predictions. Spark ML's binary
# evaluator supports only these two metrics out of the box.
bce = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction")

au_roc = bce.setMetricName('areaUnderROC').evaluate(df_pred)
au_prc = bce.setMetricName('areaUnderPR').evaluate(df_pred)

print("Area under ROC: {}".format(au_roc))
print("Area Under PR: {}".format(au_prc))

## Model Persistence

In [28]:
# Name used for the persisted model: model_name without its file extension.
print(os.path.splitext(model_name)[0])

In [29]:
## NOTE: by default the model is saved to and loaded from /dbfs/ instead of cwd!
# Persist the selected pipeline under the model name without its extension.
final_model_name = os.path.splitext(model_name)[0]
model_p_best.write().overwrite().save(final_model_name)
# Report the path actually written. model_dbfs points at the ".mml" folder,
# which is not where this save goes.
print("saved model to {}".format(os.path.join("/dbfs", final_model_name)))

In [30]:
%sh

# List the contents of the saved pipeline folder through the /dbfs mount.
ls -la /dbfs/PdM_logistic_regression/*

In [31]:
# You can ignore this code, we use it for testing our notebooks.
# Regression guard: the selected model's test-set ROC AUC must stay above 0.82.
assert au_roc > .82, "Expected area under ROC > 0.82, got {}".format(au_roc)

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.