Skip to content

Commit be00e66

Browse files
committed
added explanation again
1 parent 58ca1c2 commit be00e66

2 files changed

Lines changed: 26 additions & 1 deletion

File tree

code/training/train.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,15 @@
3030
from sklearn.linear_model import Ridge
3131
from sklearn.metrics import mean_squared_error
3232
from sklearn.model_selection import train_test_split
33+
3334
from sklearn.externals import joblib
3435
import numpy as np
3536

37+
from interpret.ext.blackbox import TabularExplainer
38+
from azureml.contrib.explain.model.explanation.explanation_client import (
39+
ExplanationClient
40+
)
41+
3642
parser = argparse.ArgumentParser("train")
3743
parser.add_argument(
3844
"--release_id",
@@ -57,6 +63,7 @@
5763
run = Run.get_context()
5864
exp = run.experiment
5965
ws = run.experiment.workspace
66+
client = ExplanationClient.from_run(run)
6067

6168
X, y = load_diabetes(return_X_y=True)
6269
columns = ["age", "gender", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"]
@@ -77,6 +84,19 @@
7784
preds = reg.predict(data["test"]["X"])
7885
run.log("mse", mean_squared_error(preds, data["test"]["y"]))
7986

87+
# create an explainer to validate or debug the model
88+
tabular_explainer = TabularExplainer(reg,
89+
initialization_examples=X_train,
90+
features=columns)
91+
# explain overall model predictions (global explanation)
92+
# passing in test dataset for evaluation examples
93+
94+
global_explanation = tabular_explainer.explain_global(X_test)
95+
96+
# uploading model explanation data for storage or visualization
97+
comment = 'Global explanation on of Diabetes Regression'
98+
client.upload_model_explanation(global_explanation, comment=comment)
99+
80100
with open(model_name, "wb") as file:
81101
joblib.dump(value=reg, filename=model_name)
82102

ml_service/pipelines/build_train_pipeline.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ def main():
5252
'scikit-learn', 'tensorflow', 'keras'],
5353
pip_packages=['azure', 'azureml-core',
5454
'azure-storage',
55-
'azure-storage-blob'])
55+
'azure-storage-blob', 'azureml-defaults',
56+
'azureml-contrib-interpret',
57+
'azureml-telemetry',
58+
'azureml-interpret',
59+
'sklearn-pandas',
60+
'azureml-dataprep'])
5661
)
5762
run_config.environment.docker.enabled = True
5863

0 commit comments

Comments
 (0)