This repository contains the official code for the paper "Explainability of Machine Learning Models under Missing Data".
The purpose of this research is to evaluate the impact of different data imputation methods on the explainability of machine learning models. Specifically, we introduce missing data into various datasets, apply several imputation techniques (e.g., Mean Imputation, MICE, missForest, GAIN), and then train XGBoost models.
We use SHAP (SHapley Additive exPlanations) to explain the predictions of these models. By comparing the SHAP values from models trained on imputed data to the SHAP values from a model trained on the original, complete data, we can quantify how well each imputation method preserves model explainability.
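Concretely, this comparison boils down to a mean squared difference between two SHAP-value matrices. A minimal sketch with toy arrays standing in for real SHAP outputs (the helper name `shap_mse` is illustrative, not part of the repository's API):

```python
import numpy as np

def shap_mse(shap_original, shap_imputed):
    """Mean squared error between two SHAP-value matrices
    of shape (n_samples, n_features)."""
    return float(np.mean((shap_original - shap_imputed) ** 2))

# Toy SHAP matrices (2 samples x 3 features) standing in for real outputs
shap_full = np.array([[0.1, -0.2, 0.3], [0.0, 0.5, -0.1]])
shap_mean_imp = np.array([[0.1, -0.1, 0.3], [0.1, 0.4, -0.1]])
print(shap_mse(shap_full, shap_mean_imp))  # ≈ 0.005: a small value means the explanations are largely preserved
```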
For the code to run correctly, your project should follow this structure. The `funcs` directory contains the implementations for the different imputation methods and utility functions used in the notebooks.
```
.
├── funcs/
│   ├── utils.py           # Utility functions (e.g., generate_missing_data)
│   ├── explain.py         # Main helper functions for running experiments
│   ├── explainNumpy.py    # Helper functions for numpy data (used in MNIST)
│   ├── DIMV.py            # DIMV imputation implementation
│   ├── miss_forest.py     # missForest imputation implementation
│   └── GAIN/
│       └── gain.py        # GAIN imputation implementation
├── results/               # Directory where output plots/tables are saved
├── XGB_clf_glass_rate 02.ipynb
├── XGBRegressor_california_rate02.ipynb
├── XGBRegressor_diabetes_rate02.ipynb
├── XGB mnist with GAIN 02.ipynb
└── README.md
```
Note: You must create the results/ directory yourself before running the notebooks.
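For example, from the repository root:

```shell
mkdir -p results   # -p avoids an error if the directory already exists
```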
It is highly recommended to use a virtual environment (like conda or venv) to manage dependencies.
- Create and activate a new virtual environment.
- Install the required libraries. The main dependencies are: `jupyter`, `scikit-learn`, `numpy`, `pandas`, `seaborn`, `matplotlib`, `shap`, `xgboost`, `tensorflow` (for GAIN imputation), and `torch` & `torchvision` (for the MNIST experiment). You can install them via pip:

  ```
  pip install jupyter scikit-learn numpy pandas seaborn matplotlib shap xgboost tensorflow torch torchvision
  ```

- Ensure the `funcs` directory is in the same root folder as the notebooks. The notebooks import imputation methods and helpers directly from this folder.
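A typical setup with `venv` might look like the following (conda works equally well; the environment name `.venv` is just a convention):

```shell
# Create and activate an isolated environment
python -m venv .venv
source .venv/bin/activate    # on Windows: .venv\Scripts\activate
python -m pip install --upgrade pip
```

After activation, run the `pip install` command above inside the environment.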
Each Jupyter notebook (`.ipynb`) is a self-contained experiment on a specific dataset. To keep the repository clean, we upload notebooks only for the 20% missing rate; higher missing rates can be obtained by changing the `missing_rate` parameter in each notebook.
- `XGB_clf_glass_rate 02.ipynb`
  - Task: Classification
  - Model: `xgboost.XGBClassifier`
  - Dataset: Glass Identification (downloaded automatically from the UCI ML repository)
  - Imputers: Mean, MICE, DIMV, missForest, SOFT-IMPUTE, GAIN
- `XGBRegressor_diabetes_rate02.ipynb`
  - Task: Regression
  - Model: `xgboost.XGBRegressor`
  - Dataset: Diabetes (loaded from `sklearn.datasets`)
  - Imputers: Mean, MICE, DIMV, missForest, SOFT-IMPUTE, GAIN
- `XGBRegressor_california_rate02.ipynb`
  - Task: Regression
  - Model: `xgboost.XGBRegressor`
  - Dataset: California Housing (loaded from `shap.datasets`)
  - Imputers: Mean, MICE, DIMV, missForest, SOFT-IMPUTE, GAIN
- `XGB mnist with GAIN 02.ipynb`
  - Task: Classification
  - Model: `xgboost.XGBClassifier`
  - Dataset: MNIST (downloaded automatically via `torchvision.datasets`)
  - Imputers: Mean, MICE, DIMV, missForest, SOFT-IMPUTE, GAIN
Each notebook follows a similar workflow:
- Load Data: The dataset is loaded (either from a library or downloaded).
- Preprocessing: Data is split into training/test sets and standardized.
- Generate Missingness: A copy of the data is created with a 20% missing rate (`X_train_star`, `X_test_star`).
- Imputation: The data with missing values is imputed using various methods.
- Model Training: An XGBoost model is trained on the original (complete) data and on each of the imputed datasets.
- SHAP Analysis: SHAP values are calculated for the test set predictions of each model.
- Evaluation: The "MSE Shap" (mean squared error between SHAP values from the imputed model and the original model) is calculated to measure the impact on explainability. Other metrics, such as prediction MSE, are also computed.
- Results: Tables and plots comparing the performance and explainability metrics are generated and saved to the `results/` folder.
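The missingness-generation and imputation steps can be sketched as follows. This is an illustrative reimplementation using MCAR masking and mean imputation only; the repository's actual helpers live in `funcs/utils.py` and the imputer modules, and may differ in signature and behavior:

```python
import numpy as np

def generate_missing_data(X, missing_rate, seed=0):
    """Return a copy of X with entries set to NaN uniformly at random (MCAR)."""
    rng = np.random.default_rng(seed)
    X_star = X.astype(float).copy()
    mask = rng.random(X.shape) < missing_rate
    X_star[mask] = np.nan
    return X_star

def mean_impute(X_star):
    """Replace NaNs with the per-column mean of the observed entries."""
    col_means = np.nanmean(X_star, axis=0)
    return np.where(np.isnan(X_star), col_means, X_star)

X = np.arange(400, dtype=float).reshape(100, 4)       # stand-in for a real dataset
X_star = generate_missing_data(X, missing_rate=0.2)   # ~20% of entries become NaN
X_imp = mean_impute(X_star)
assert not np.isnan(X_imp).any()                      # imputed data is complete
```

Downstream, an XGBoost model would be trained on both `X` and `X_imp`, and their SHAP values compared.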