Skip to content

Commit 36e5537

Browse files
committed
Fix weight column issue
1 parent 8e88992 commit 36e5537

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

00_Miscellaneous/tfx/02_tfx_end_to_end.ipynb

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"3. Model training with **TF Estimator**.\n",
1515
"4. Model evaluation with **TF Model Analysis**.\n",
1616
"\n",
17-
"<a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/sme_academy/02_tfx_end_to_end.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
17+
"<a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/tf-estimator-tutorials/blob/master/00_Miscellaneous/tfx/02_tfx_end_to_end.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
1818
]
1919
},
2020
{
@@ -469,7 +469,8 @@
469469
" for feature in raw_schema.feature:\n",
470470
" feature_name = feature.name\n",
471471
" \n",
472-
" if feature_name in ['income_bracket', 'fnlwgt']:\n",
472+
" # Pass the target feature as is.\n",
473+
" if feature_name == TARGET_FEATURE_NAME:\n",
473474
" processed_features[feature_name] = input_features[feature_name]\n",
474475
" continue\n",
475476
"\n",
@@ -480,6 +481,9 @@
480481
" # normalize numeric features.\n",
481482
" processed_features[feature_name+\"_scaled\"] = tft.scale_to_z_score(input_features[feature_name])\n",
482483
"\n",
484+
" # Pass the weight column\n",
485+
" processed_features[WEIGHT_COLUMN_NAME] = input_features[WEIGHT_COLUMN_NAME]\n",
486+
"\n",
483487
" # Bucketize age using quantiles. \n",
484488
" quantiles = tft.quantiles(input_features[\"age\"], num_buckets=5, epsilon=0.01)\n",
485489
" processed_features[\"age_bucketized\"] = tft.apply_buckets(\n",
@@ -536,6 +540,9 @@
536540
" # Load TFDV schema and create tft schema from it.\n",
537541
" source_raw_schema = tfdv.load_schema_text(raw_schema_location)\n",
538542
" raw_feature_spec = schema_utils.schema_as_feature_spec(source_raw_schema).feature_spec\n",
543+
"    # Since the raw_feature_spec doesn't include the weight column, we need to add it. \n",
544+
" raw_feature_spec[WEIGHT_COLUMN_NAME] = tf.FixedLenFeature(\n",
545+
" shape=[1], dtype=tf.int64, default_value=None)\n",
539546
" raw_metadata = dataset_metadata.DatasetMetadata(\n",
540547
" dataset_schema.from_feature_spec(raw_feature_spec))\n",
541548
"\n",
@@ -1096,7 +1103,6 @@
10961103
" source_raw_schema = tfdv.load_schema_text(RAW_SCHEMA_LOCATION)\n",
10971104
" raw_feature_spec = schema_utils.schema_as_feature_spec(source_raw_schema).feature_spec\n",
10981105
" raw_feature_spec.pop(TARGET_FEATURE_NAME)\n",
1099-
" raw_feature_spec.pop(WEIGHT_COLUMN_NAME)\n",
11001106
"\n",
11011107
" # Create the interface for the serving function with the raw features\n",
11021108
" raw_features = tf.estimator.export.build_parsing_serving_input_receiver_fn(raw_feature_spec)().features\n",
@@ -1137,7 +1143,7 @@
11371143
" \n",
11381144
"estimator.export_savedmodel(\n",
11391145
" export_dir_base=export_dir,\n",
1140-
" serving_input_receiver_fn=input_receiver_fn\n",
1146+
" serving_input_receiver_fn=serving_input_receiver_fn\n",
11411147
")"
11421148
]
11431149
},

0 commit comments

Comments
 (0)