|
82 | 82 | "from sklearn import svm\n", |
83 | 83 | "from sklearn.preprocessing import LabelEncoder, StandardScaler\n", |
84 | 84 | "from sklearn.linear_model import LogisticRegression\n", |
85 | | - "import pandas as pd\n", |
86 | | - "import shap" |
| 85 | + "import pandas as pd" |
87 | 86 | ] |
88 | 87 | }, |
89 | 88 | { |
|
99 | 98 | "metadata": {}, |
100 | 99 | "outputs": [], |
101 | 100 | "source": [ |
102 | | - "X_raw, Y = shap.datasets.adult()\n", |
103 | | - "X_raw[\"Race\"].value_counts().to_dict()" |
| 101 | + "from sklearn.datasets import fetch_openml\n", |
| 102 | + "data = fetch_openml(data_id=1590, as_frame=True)\n", |
| 103 | + "X_raw = data.data\n", |
| 104 | + "Y = (data.target == '>50K') * 1\n", |
| 105 | + "\n", |
| 106 | + "X_raw[\"race\"].value_counts().to_dict()" |
104 | 107 | ] |
105 | 108 | }, |
106 | 109 | { |
|
116 | 119 | "metadata": {}, |
117 | 120 | "outputs": [], |
118 | 121 | "source": [ |
119 | | - "A = X_raw[['Sex','Race']]\n", |
120 | | - "X = X_raw.drop(labels=['Sex', 'Race'],axis = 1)\n", |
121 | | - "X = pd.get_dummies(X)\n", |
| 122 | + "A = X_raw[['sex','race']]\n", |
| 123 | + "X = X_raw.drop(labels=['sex', 'race'],axis = 1)\n", |
| 124 | + "X_dummies = pd.get_dummies(X)\n", |
| 125 | + "\n", |
| 126 | + "sc = StandardScaler()\n", |
| 127 | + "X_scaled = sc.fit_transform(X_dummies)\n", |
| 128 | + "X_scaled = pd.DataFrame(X_scaled, columns=X_dummies.columns)\n", |
122 | 129 | "\n", |
123 | 130 | "\n", |
124 | 131 | "le = LabelEncoder()\n", |
|
139 | 146 | "outputs": [], |
140 | 147 | "source": [ |
141 | 148 | "from sklearn.model_selection import train_test_split\n", |
142 | | - "X_train, X_test, Y_train, Y_test, A_train, A_test = train_test_split(X_raw, \n", |
| 149 | + "X_train, X_test, Y_train, Y_test, A_train, A_test = train_test_split(X_scaled, \n", |
143 | 150 | " Y, \n", |
144 | 151 | " A,\n", |
145 | 152 | " test_size = 0.2,\n", |
|
150 | 157 | "X_train = X_train.reset_index(drop=True)\n", |
151 | 158 | "A_train = A_train.reset_index(drop=True)\n", |
152 | 159 | "X_test = X_test.reset_index(drop=True)\n", |
153 | | - "A_test = A_test.reset_index(drop=True)\n", |
154 | | - "\n", |
155 | | - "# Improve labels\n", |
156 | | - "A_test.Sex.loc[(A_test['Sex'] == 0)] = 'female'\n", |
157 | | - "A_test.Sex.loc[(A_test['Sex'] == 1)] = 'male'\n", |
158 | | - "\n", |
159 | | - "\n", |
160 | | - "A_test.Race.loc[(A_test['Race'] == 0)] = 'Amer-Indian-Eskimo'\n", |
161 | | - "A_test.Race.loc[(A_test['Race'] == 1)] = 'Asian-Pac-Islander'\n", |
162 | | - "A_test.Race.loc[(A_test['Race'] == 2)] = 'Black'\n", |
163 | | - "A_test.Race.loc[(A_test['Race'] == 3)] = 'Other'\n", |
164 | | - "A_test.Race.loc[(A_test['Race'] == 4)] = 'White'" |
| 160 | + "A_test = A_test.reset_index(drop=True)" |
165 | 161 | ] |
166 | 162 | }, |
167 | 163 | { |
|
251 | 247 | "outputs": [], |
252 | 248 | "source": [ |
253 | 249 | "sweep.fit(X_train, Y_train,\n", |
254 | | - " sensitive_features=A_train.Sex)\n", |
| 250 | + " sensitive_features=A_train.sex)\n", |
255 | 251 | "\n", |
256 | 252 | "predictors = sweep._predictors" |
257 | 253 | ] |
|
274 | 270 | " classifier = lambda X: m.predict(X)\n", |
275 | 271 | " \n", |
276 | 272 | " error = ErrorRate()\n", |
277 | | - " error.load_data(X_train, pd.Series(Y_train), sensitive_features=A_train.Sex)\n", |
| 273 | + " error.load_data(X_train, pd.Series(Y_train), sensitive_features=A_train.sex)\n", |
278 | 274 | " disparity = DemographicParity()\n", |
279 | | - " disparity.load_data(X_train, pd.Series(Y_train), sensitive_features=A_train.Sex)\n", |
| 275 | + " disparity.load_data(X_train, pd.Series(Y_train), sensitive_features=A_train.sex)\n", |
280 | 276 | " \n", |
281 | 277 | " errors.append(error.gamma(classifier)[0])\n", |
282 | 278 | " disparities.append(disparity.gamma(classifier).max())\n", |
|
440 | 436 | "metadata": {}, |
441 | 437 | "outputs": [], |
442 | 438 | "source": [ |
443 | | - "sf = { 'sex': A_test.Sex, 'race': A_test.Race }\n", |
| 439 | + "sf = { 'sex': A_test.sex, 'race': A_test.race }\n", |
444 | 440 | "\n", |
445 | 441 | "from fairlearn.metrics._group_metric_set import _create_group_metric_set\n", |
446 | 442 | "\n", |
|
0 commit comments