|
205 | 205 | "container = get_image_uri(session.boto_region_name, 'xgboost')\n", |
206 | 206 | "\n", |
207 | 207 | "# We now specify the parameters we wish to use for our training job\n", |
208 | | - "training_params = \\\n", |
209 | | - "{\n", |
210 | | - " # We need to specify the permissions that this training job will have. For our purposes we can use\n", |
211 | | - " # the same permissions that our current SageMaker session has.\n", |
212 | | - " \"RoleArn\": role,\n", |
213 | | - " \n", |
214 | | - " # Here we describe the algorithm we wish to use. The most important part is the container which\n", |
215 | | - " # contains the training code.\n", |
216 | | - " \"AlgorithmSpecification\": {\n", |
217 | | - " \"TrainingImage\": container,\n", |
218 | | - " \"TrainingInputMode\": \"File\"\n", |
219 | | - " },\n", |
220 | | - " \n", |
221 | | - " # Next we set the algorithm specific hyperparameters. You may wish to change these to see what effect\n", |
222 | | - " # there is on the resulting model.\n", |
223 | | - " \"HyperParameters\": {\n", |
224 | | - " \"max_depth\": \"5\",\n", |
225 | | - " \"eta\": \"0.2\",\n", |
226 | | - " \"gamma\": \"4\",\n", |
227 | | - " \"min_child_weight\": \"6\",\n", |
228 | | - " \"subsample\": \"0.8\",\n", |
229 | | - " \"objective\": \"reg:linear\",\n", |
230 | | - " \"early_stopping_rounds\": \"10\",\n", |
231 | | - " \"num_round\": \"200\"\n", |
232 | | - " },\n", |
| 208 | + "training_params = {}\n", |
| 209 | + "\n", |
| 210 | + "# We need to specify the permissions that this training job will have. For our purposes we can use\n", |
| 211 | + "# the same permissions that our current SageMaker session has.\n", |
| 212 | + "training_params['RoleArn'] = role\n", |
| 213 | + "\n", |
| 214 | + "# Here we describe the algorithm we wish to use. The most important part is the container which\n", |
| 215 | + "# contains the training code.\n", |
| 216 | + "training_params['AlgorithmSpecification'] = {\n", |
| 217 | + " \"TrainingImage\": container,\n", |
| 218 | + " \"TrainingInputMode\": \"File\"\n", |
| 219 | + "}\n", |
| 220 | + "\n", |
| 221 | + "# We also need to say where we would like the resulting model artifacts stored.\n", |
| 222 | + "training_params['OutputDataConfig'] = {\n", |
| 223 | + " \"S3OutputPath\": \"s3://\" + session.default_bucket() + \"/\" + prefix + \"/output\"\n", |
| 224 | + "}\n", |
| 225 | + "\n", |
| 226 | + "# We also need to set some parameters for the training job itself. Namely we need to describe what sort of\n", |
| 227 | + "# compute instance we wish to use along with a stopping condition to handle the case that there is\n", |
| 228 | + "# some sort of error and the training script doesn't terminate.\n", |
| 229 | + "training_params['ResourceConfig'] = {\n", |
| 230 | + " \"InstanceCount\": 1,\n", |
| 231 | + " \"InstanceType\": \"ml.m4.xlarge\",\n", |
| 232 | + " \"VolumeSizeInGB\": 5\n", |
| 233 | + "}\n", |
233 | 234 | " \n", |
234 | | - " # Now we need to tell SageMaker where the data should be retrieved from and where to save the\n", |
235 | | - " # resulting model artifacts.\n", |
236 | | - " \"InputDataConfig\": [\n", |
237 | | - " {\n", |
238 | | - " \"ChannelName\": \"train\",\n", |
239 | | - " \"DataSource\": {\n", |
240 | | - " \"S3DataSource\": {\n", |
241 | | - " \"S3DataType\": \"S3Prefix\",\n", |
242 | | - " \"S3Uri\": train_location,\n", |
243 | | - " \"S3DataDistributionType\": \"FullyReplicated\"\n", |
244 | | - " }\n", |
245 | | - " },\n", |
246 | | - " \"ContentType\": \"csv\",\n", |
247 | | - " \"CompressionType\": \"None\"\n", |
| 235 | + "training_params['StoppingCondition'] = {\n", |
| 236 | + " \"MaxRuntimeInSeconds\": 86400\n", |
| 237 | + "}\n", |
| 238 | + "\n", |
| 239 | + "# Next we set the algorithm specific hyperparameters. You may wish to change these to see what effect\n", |
| 240 | + "# there is on the resulting model.\n", |
| 241 | + "training_params['HyperParameters'] = {\n", |
| 242 | + " \"max_depth\": \"5\",\n", |
| 243 | + " \"eta\": \"0.2\",\n", |
| 244 | + " \"gamma\": \"4\",\n", |
| 245 | + " \"min_child_weight\": \"6\",\n", |
| 246 | + " \"subsample\": \"0.8\",\n", |
| 247 | + " \"objective\": \"reg:linear\",\n", |
| 248 | + " \"early_stopping_rounds\": \"10\",\n", |
| 249 | + " \"num_round\": \"200\"\n", |
| 250 | + "}\n", |
| 251 | + "\n", |
| 252 | + "# Now we need to tell SageMaker where the data should be retrieved from.\n", |
| 253 | + "training_params['InputDataConfig'] = [\n", |
| 254 | + " {\n", |
| 255 | + " \"ChannelName\": \"train\",\n", |
| 256 | + " \"DataSource\": {\n", |
| 257 | + " \"S3DataSource\": {\n", |
| 258 | + " \"S3DataType\": \"S3Prefix\",\n", |
| 259 | + " \"S3Uri\": train_location,\n", |
| 260 | + " \"S3DataDistributionType\": \"FullyReplicated\"\n", |
| 261 | + " }\n", |
248 | 262 | " },\n", |
249 | | - " {\n", |
250 | | - " \"ChannelName\": \"validation\",\n", |
251 | | - " \"DataSource\": {\n", |
252 | | - " \"S3DataSource\": {\n", |
253 | | - " \"S3DataType\": \"S3Prefix\",\n", |
254 | | - " \"S3Uri\": val_location,\n", |
255 | | - " \"S3DataDistributionType\": \"FullyReplicated\"\n", |
256 | | - " }\n", |
257 | | - " },\n", |
258 | | - " \"ContentType\": \"csv\",\n", |
259 | | - " \"CompressionType\": \"None\"\n", |
260 | | - " }\n", |
261 | | - " ],\n", |
262 | | - " \n", |
263 | | - " \"OutputDataConfig\": {\n", |
264 | | - " \"S3OutputPath\": \"s3://\" + session.default_bucket() + \"/\" + prefix + \"/output\"\n", |
265 | | - " },\n", |
266 | | - " \n", |
267 | | - " # Lastly we set some parameters for the training job itself. Namely we need to describe what sort of\n", |
268 | | - " # compute instance we wish to use along with a stopping condition to handle the case that there is\n", |
269 | | - " # some sort of error and the training script doesn't terminate.\n", |
270 | | - " \"ResourceConfig\": {\n", |
271 | | - " \"InstanceCount\": 1,\n", |
272 | | - " \"InstanceType\": \"ml.m4.xlarge\",\n", |
273 | | - " \"VolumeSizeInGB\": 5\n", |
| 263 | + " \"ContentType\": \"csv\",\n", |
| 264 | + " \"CompressionType\": \"None\"\n", |
274 | 265 | " },\n", |
275 | | - " \n", |
276 | | - " \"StoppingCondition\": {\n", |
277 | | - " \"MaxRuntimeInSeconds\": 86400\n", |
| 266 | + " {\n", |
| 267 | + " \"ChannelName\": \"validation\",\n", |
| 268 | + " \"DataSource\": {\n", |
| 269 | + " \"S3DataSource\": {\n", |
| 270 | + " \"S3DataType\": \"S3Prefix\",\n", |
| 271 | + " \"S3Uri\": val_location,\n", |
| 272 | + " \"S3DataDistributionType\": \"FullyReplicated\"\n", |
| 273 | + " }\n", |
| 274 | + " },\n", |
| 275 | + " \"ContentType\": \"csv\",\n", |
| 276 | + " \"CompressionType\": \"None\"\n", |
278 | 277 | " }\n", |
279 | | - "}" |
| 278 | + "]" |
280 | 279 | ] |
281 | 280 | }, |
282 | 281 | { |
|
0 commit comments