Machine learning interpretability#
In modern machine learning it is important to be able to explain how our models “think”: a single accuracy score is not enough. This notebook accompanies the lesson on interpretability.
As models become more complex, understanding how they arrive at their predictions becomes harder. That lack of transparency erodes trust and makes errors difficult to identify and correct. Interpretability is the ability to explain how a model arrived at a particular decision, and it is essential for building trust in these powerful tools.
Below we walk through several practical ways to achieve it: feature weights and per-prediction explanations with eli5, permutation importance and partial dependence plots from scikit-learn, and SHAP values.
How To#
from sklearn.model_selection import train_test_split
import pandas as pd
df = pd.read_csv("data/housing.csv")
df.head()
|   | longitude | latitude | housing_median_age | total_rooms | total_bedrooms | population | households | median_income | median_house_value | ocean_proximity |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -122.23 | 37.88 | 41.0 | 880.0 | 129.0 | 322.0 | 126.0 | 8.3252 | 452600.0 | NEAR BAY |
| 1 | -122.22 | 37.86 | 21.0 | 7099.0 | 1106.0 | 2401.0 | 1138.0 | 8.3014 | 358500.0 | NEAR BAY |
| 2 | -122.24 | 37.85 | 52.0 | 1467.0 | 190.0 | 496.0 | 177.0 | 7.2574 | 352100.0 | NEAR BAY |
| 3 | -122.25 | 37.85 | 52.0 | 1274.0 | 235.0 | 558.0 | 219.0 | 5.6431 | 341300.0 | NEAR BAY |
| 4 | -122.25 | 37.85 | 52.0 | 1627.0 | 280.0 | 565.0 | 259.0 | 3.8462 | 342200.0 | NEAR BAY |
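The raw data is not complete: `total_bedrooms` contains missing values, which is why the next cell drops incomplete rows. A quick check makes this visible (a small sketch, assuming the same CSV as above):

```python
# count missing values per column; total_bedrooms is the one with gaps
df.isna().sum()
```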
# drop rows with missing values, then hold out 50% of the data for training;
# geographic columns and the target are excluded from the features
df = df.dropna()
x_train, x_, y_train, y_ = train_test_split(
    df.drop(["longitude", "latitude", "ocean_proximity", "median_house_value"], axis=1),
    df.median_house_value, test_size=.5, stratify=df.ocean_proximity)
# split the remainder 50/50 into validation and test sets
x_val, x_test, y_val, y_test = train_test_split(x_, y_, test_size=.5)
from sklearn.ensemble import RandomForestRegressor
model = RandomForestRegressor()
model.fit(x_train, y_train)
RandomForestRegressor()
model.score(x_val, y_val)
0.6268484183250347
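For a regressor, `.score` reports the coefficient of determination (R²) on the data you pass it, so the ~0.63 above is R² on the validation set. Equivalently (a small sketch, reusing the fitted model):

```python
from sklearn.metrics import r2_score

# same number as model.score(x_val, y_val): R² of the predictions on the validation set
r2_score(y_val, model.predict(x_val))
```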
Influence of Variables#
import eli5
eli5.explain_weights(model)
| Weight | Feature |
|---|---|
| 0.5702 ± 0.0193 | x5 |
| 0.1007 ± 0.0122 | x3 |
| 0.0999 ± 0.0110 | x0 |
| 0.0869 ± 0.0136 | x2 |
| 0.0750 ± 0.0112 | x1 |
| 0.0672 ± 0.0126 | x4 |
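The features appear as x0 … x5 because the explainer only sees a NumPy array internally. Passing the column names makes the table readable; a small sketch, assuming the same model and training frame:

```python
# map x0…x5 back to the actual column names
eli5.explain_weights(model, feature_names=list(x_train.columns))
```

With that mapping, x5 is median_income, which is consistent with it dominating the per-row explanations below.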
# explain the model's prediction for the first five training rows
for x in range(5):
    display(eli5.explain_prediction(model, x_train.iloc[x, :]))
y (score 192155.010) top features

| Contribution? | Feature |
|---|---|
| +206608.756 | `<BIAS>` |
| +23749.665 | population |
| +4237.507 | total_rooms |
| -2911.587 | households |
| -7329.428 | total_bedrooms |
| -8065.454 | median_income |
| -24134.450 | housing_median_age |

y (score 219967.000) top features

| Contribution? | Feature |
|---|---|
| +206608.756 | `<BIAS>` |
| +29156.010 | median_income |
| +13227.272 | total_bedrooms |
| +7458.089 | households |
| -1647.261 | housing_median_age |
| -2133.702 | total_rooms |
| -32702.165 | population |

y (score 164051.000) top features

| Contribution? | Feature |
|---|---|
| +206608.756 | `<BIAS>` |
| +636.694 | total_bedrooms |
| -286.107 | total_rooms |
| -347.708 | median_income |
| -1879.641 | households |
| -2102.477 | population |
| -38578.517 | housing_median_age |

y (score 90957.000) top features

| Contribution? | Feature |
|---|---|
| +206608.756 | `<BIAS>` |
| +2167.274 | total_bedrooms |
| -9794.188 | households |
| -13850.356 | housing_median_age |
| -18001.863 | population |
| -23218.913 | total_rooms |
| -52953.711 | median_income |

y (score 369616.040) top features

| Contribution? | Feature |
|---|---|
| +206608.756 | `<BIAS>` |
| +177655.659 | median_income |
| +8726.026 | total_bedrooms |
| -1901.498 | total_rooms |
| -3043.081 | households |
| -7012.409 | housing_median_age |
| -11417.412 | population |
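`<BIAS>` is the value at the root of the trees (roughly the mean target over the training data); the feature contributions describe how each row's values push the prediction away from that baseline along the decision paths, and together with the bias they sum to the score shown. A quick sanity check for the first row (a sketch, reusing the fitted model):

```python
# bias + contributions from the first table ≈ the model's actual prediction
contributions = [206608.756, 23749.665, 4237.507, -2911.587, -7329.428, -8065.454, -24134.450]
print(sum(contributions))                  # ≈ 192155.01, the "score" above
print(model.predict(x_train.iloc[[0]]))    # should give the same value
```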
from sklearn.inspection import permutation_importance
permutation_importance(model, x_train, y_train)
{'importances_mean': array([0.32936848, 0.23403235, 0.41975637, 0.39078343, 0.32032987,
1.56666893]),
'importances_std': array([0.00342922, 0.00260159, 0.00483743, 0.00583837, 0.00345139,
0.01029038]),
'importances': array([[0.33037129, 0.33300722, 0.33172831, 0.32855788, 0.32317768],
[0.23873456, 0.23493173, 0.23165493, 0.23223061, 0.23260994],
[0.42385855, 0.42404604, 0.41103279, 0.42159832, 0.41824618],
[0.39994676, 0.38680695, 0.38822469, 0.38396244, 0.39497633],
[0.31744121, 0.32319307, 0.31515865, 0.32410591, 0.32175051],
[1.57431233, 1.56523047, 1.5607892 , 1.5812431 , 1.55176957]])}
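The raw dictionary is hard to read because the positions only correspond to the column order of x_train. Pairing the mean importances with the column names, and preferably scoring on held-out data rather than the training set, gives something easier to interpret; a small sketch:

```python
import pandas as pd

# permute each feature on the validation set and report the mean drop in R²
perm = permutation_importance(model, x_val, y_val, n_repeats=5, random_state=0)
pd.Series(perm.importances_mean, index=x_val.columns).sort_values(ascending=False)
```

In the output above the last entry (median_income, ~1.57) again dominates, matching the eli5 weights.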
from sklearn.inspection import plot_partial_dependence
plot_partial_dependence(model, x_train, x_train.columns)
<sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay at 0x7ff38ea48280>

Shap#
import shap
expl = shap.TreeExplainer(model)
# a background dataset can also be passed; it changes how the expected value
# and SHAP values are estimated (interventional rather than tree-path-dependent)
shap.TreeExplainer(model, data=x_train)
<shap.explainers.tree.TreeExplainer at 0x7ff38c493160>
shap_val = expl.shap_values(x_val)
shap.initjs()
shap.force_plot(expl.expected_value, shap_val[0, :], x_val.iloc[0, :])
(The force plot renders as interactive JavaScript and is not shown in this static export.)
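Force plots need the shap JavaScript to render, so they will not appear in a static export. The matplotlib-based plots work anywhere; for example, a summary plot over the same SHAP values (a sketch):

```python
# ranks features by overall impact and shows how their values push predictions up or down
shap.summary_plot(shap_val, x_val)
```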
Exercise#
Check out shap further and see which plots you can generate.