|  | 
| 451 | 451 |       "metadata": {}, | 
| 452 | 452 |       "source": [ | 
| 453 | 453 |         "### Create a dataset of training artifacts\n", | 
| 454 |  | -        "To evaluate a trained policy (a checkpoint) we need to make the checkpoint accessible to the rollout script. All the training artifacts are stored in workspace default datastore under **azureml/<run_id>** directory.\n", | 
| 455 |  | -        "\n", | 
| 456 |  | -        "Here we create a file dataset from the stored artifacts, and then use this dataset to feed these data to rollout estimator." | 
|  | 454 | +        "To evaluate a trained policy (a checkpoint) we need to make the checkpoint accessible to the rollout script.\n", | 
|  | 455 | +        "We can use the Run API to download policy training artifacts (saved model and checkpoints) to local compute." | 
| 457 | 456 |       ] | 
| 458 | 457 |     }, | 
| 459 | 458 |     { | 
|  | 
| 462 | 461 |       "metadata": {}, | 
| 463 | 462 |       "outputs": [], | 
| 464 | 463 |       "source": [ | 
| 465 |  | -        "from azureml.core import Dataset\n", | 
|  | 464 | +        "from os import path\n", | 
|  | 465 | +        "from distutils import dir_util\n", | 
| 466 | 466 |         "\n", | 
| 467 |  | -        "run_id = child_run_0.id # Or set to run id of a completed run (e.g. 'rl-cartpole-v0_1587572312_06e04ace_head')\n", | 
| 468 |  | -        "run_artifacts_path = os.path.join('azureml', run_id)\n", | 
| 469 |  | -        "print(\"Run artifacts path:\", run_artifacts_path)\n", | 
|  | 467 | +        "training_artifacts_path = path.join(\"logs\", training_algorithm)\n", | 
|  | 468 | +        "print(\"Training artifacts path:\", training_artifacts_path)\n", | 
| 470 | 469 |         "\n", | 
| 471 |  | -        "# Create a file dataset object from the files stored on default datastore\n", | 
| 472 |  | -        "datastore = ws.get_default_datastore()\n", | 
| 473 |  | -        "training_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(run_artifacts_path, '**')))" | 
|  | 470 | +        "if path.exists(training_artifacts_path):\n", | 
|  | 471 | +        "    dir_util.remove_tree(training_artifacts_path)\n", | 
|  | 472 | +        "\n", | 
|  | 473 | +        "# Download run artifacts to local compute\n", | 
|  | 474 | +        "child_run_0.download_files(training_artifacts_path)" | 
| 474 | 475 |       ] | 
| 475 | 476 |     }, | 
| 476 | 477 |     { | 
| 477 | 478 |       "cell_type": "markdown", | 
| 478 | 479 |       "metadata": {}, | 
| 479 | 480 |       "source": [ | 
| 480 |  | -        "To verify, we can print out the number (and paths) of all the files in the dataset, as follows." | 
|  | 481 | +        "Now let's find the checkpoints and the last checkpoint number." | 
| 481 | 482 |       ] | 
| 482 | 483 |     }, | 
| 483 | 484 |     { | 
|  | 
| 486 | 487 |       "metadata": {}, | 
| 487 | 488 |       "outputs": [], | 
| 488 | 489 |       "source": [ | 
| 489 |  | -        "artifacts_paths = training_artifacts_ds.to_path()\n", | 
| 490 |  | -        "print(\"Number of files in dataset:\", len(artifacts_paths))\n", | 
| 491 |  | -        "\n", | 
| 492 |  | -        "# Uncomment line below to print all file paths\n", | 
| 493 |  | -        "#print(\"Artifacts dataset file paths: \", artifacts_paths)" | 
|  | 490 | +        "# A helper function to find checkpoint files in a directory\n", | 
|  | 491 | +        "def find_checkpoints(file_path):\n", | 
|  | 492 | +        "    print(\"Looking in path:\", file_path)\n", | 
|  | 493 | +        "    checkpoints = []\n", | 
|  | 494 | +        "    for root, _, files in os.walk(file_path):\n", | 
|  | 495 | +        "        for name in files:\n", | 
|  | 496 | +        "            if os.path.basename(root).startswith('checkpoint_'):\n", | 
|  | 497 | +        "                checkpoints.append(path.join(root, name))\n", | 
|  | 498 | +        "    return checkpoints" | 
| 494 | 499 |       ] | 
| 495 | 500 |     }, | 
| 496 | 501 |     { | 
| 497 |  | -      "cell_type": "markdown", | 
|  | 502 | +      "cell_type": "code", | 
|  | 503 | +      "execution_count": null, | 
| 498 | 504 |       "metadata": {}, | 
|  | 505 | +      "outputs": [], | 
| 499 | 506 |       "source": [ | 
| 500 |  | -        "### Evaluate a trained policy\n", | 
| 501 |  | -        "We need to configure another reinforcement learning estimator, `rollout_estimator`, and then use it to submit another run. Note that the entry script for this estimator now points to `cartpole-rollout.py` script.\n", | 
| 502 |  | -        "Also note how we pass the checkpoints dataset to this script using `inputs` parameter of the _ReinforcementLearningEstimator_.\n", | 
|  | 507 | +        "# Find checkpoints and last checkpoint number\n", | 
|  | 508 | +        "checkpoint_files = find_checkpoints(training_artifacts_path)\n", | 
| 503 | 509 |         "\n", | 
| 504 |  | -        "We are using script parameters to pass in the same algorithm and the same environment used during training. We also specify the checkpoint number of the checkpoint we wish to evaluate, `checkpoint-number`, and number of the steps we shall run the rollout, `steps`.\n", | 
|  | 510 | +        "checkpoint_numbers = []\n", | 
|  | 511 | +        "for file in checkpoint_files:\n", | 
|  | 512 | +        "    file = os.path.basename(file)\n", | 
|  | 513 | +        "    if file.startswith('checkpoint-') and not file.endswith('.tune_metadata'):\n", | 
|  | 514 | +        "        checkpoint_numbers.append(int(file.split('-')[1]))\n", | 
| 505 | 515 |         "\n", | 
| 506 |  | -        "The checkpoints dataset will be accessible to the rollout script as a mounted folder. The mounted folder and the checkpoint number, passed in via `checkpoint-number`, will be used to create a path to the checkpoint we are going to evaluate. The created checkpoint path then will be passed into RLlib rollout script for evaluation.\n", | 
|  | 516 | +        "print(\"Checkpoints:\", checkpoint_numbers)\n", | 
| 507 | 517 |         "\n", | 
| 508 |  | -        "Let's find the checkpoints and the last checkpoint number first." | 
|  | 518 | +        "last_checkpoint_number = max(checkpoint_numbers)\n", | 
|  | 519 | +        "print(\"Last checkpoint number:\", last_checkpoint_number)" | 
|  | 520 | +      ] | 
|  | 521 | +    }, | 
|  | 522 | +    { | 
|  | 523 | +      "cell_type": "markdown", | 
|  | 524 | +      "metadata": {}, | 
|  | 525 | +      "source": [ | 
|  | 526 | +        "Now we upload checkpoints to default datastore and create a file dataset. This dataset will be used to pass in the checkpoints to the rollout script." | 
| 509 | 527 |       ] | 
| 510 | 528 |     }, | 
| 511 | 529 |     { | 
|  | 
| 514 | 532 |       "metadata": {}, | 
| 515 | 533 |       "outputs": [], | 
| 516 | 534 |       "source": [ | 
| 517 |  | -        "# Find checkpoints and last checkpoint number\n", | 
| 518 |  | -        "checkpoint_files = [\n", | 
| 519 |  | -        "    os.path.basename(file) for file in training_artifacts_ds.to_path() \\\n", | 
| 520 |  | -        "        if os.path.basename(file).startswith('checkpoint-') and \\\n", | 
| 521 |  | -        "            not os.path.basename(file).endswith('tune_metadata')\n", | 
| 522 |  | -        "]\n", | 
| 523 |  | -        "\n", | 
| 524 |  | -        "checkpoint_numbers = []\n", | 
| 525 |  | -        "for file in checkpoint_files:\n", | 
| 526 |  | -        "    checkpoint_numbers.append(int(file.split('-')[1]))\n", | 
|  | 535 | +        "# Upload the checkpoint files and create a DataSet\n", | 
|  | 536 | +        "from azureml.core import Dataset\n", | 
| 527 | 537 |         "\n", | 
| 528 |  | -        "print(\"Checkpoints:\", checkpoint_numbers)\n", | 
|  | 538 | +        "datastore = ws.get_default_datastore()\n", | 
|  | 539 | +        "checkpoint_dataref = datastore.upload_files(checkpoint_files, target_path='cartpole_checkpoints_' + run_id, overwrite=True)\n", | 
|  | 540 | +        "checkpoint_ds = Dataset.File.from_files(checkpoint_dataref)" | 
|  | 541 | +      ] | 
|  | 542 | +    }, | 
|  | 543 | +    { | 
|  | 544 | +      "cell_type": "markdown", | 
|  | 545 | +      "metadata": {}, | 
|  | 546 | +      "source": [ | 
|  | 547 | +        "To verify, we can print out the number (and paths) of all the files in the dataset." | 
|  | 548 | +      ] | 
|  | 549 | +    }, | 
|  | 550 | +    { | 
|  | 551 | +      "cell_type": "code", | 
|  | 552 | +      "execution_count": null, | 
|  | 553 | +      "metadata": {}, | 
|  | 554 | +      "outputs": [], | 
|  | 555 | +      "source": [ | 
|  | 556 | +        "artifacts_paths = checkpoint_ds.to_path()\n", | 
|  | 557 | +        "print(\"Number of files in dataset:\", len(artifacts_paths))\n", | 
| 529 | 558 |         "\n", | 
| 530 |  | -        "last_checkpoint_number = max(checkpoint_numbers)\n", | 
| 531 |  | -        "print(\"Last checkpoint number:\", last_checkpoint_number)" | 
|  | 559 | +        "# Uncomment line below to print all file paths\n", | 
|  | 560 | +        "#print(\"Artifacts dataset file paths: \", artifacts_paths)" | 
| 532 | 561 |       ] | 
| 533 | 562 |     }, | 
| 534 | 563 |     { | 
| 535 | 564 |       "cell_type": "markdown", | 
| 536 | 565 |       "metadata": {}, | 
| 537 | 566 |       "source": [ | 
|  | 567 | +        "### Evaluate a trained policy\n", | 
|  | 568 | +        "We need to configure another reinforcement learning estimator, `rollout_estimator`, and then use it to submit another run. Note that the entry script for this estimator now points to `cartpole-rollout.py` script.\n", | 
|  | 569 | +        "Also note how we pass the checkpoints dataset to this script using `inputs` parameter of the _ReinforcementLearningEstimator_.\n", | 
|  | 570 | +        "\n", | 
|  | 571 | +        "We are using script parameters to pass in the same algorithm and the same environment used during training. We also specify the checkpoint number of the checkpoint we wish to evaluate, `checkpoint-number`, and number of the steps we shall run the rollout, `steps`.\n", | 
|  | 572 | +        "\n", | 
|  | 573 | +        "The checkpoints dataset will be accessible to the rollout script as a mounted folder. The mounted folder and the checkpoint number, passed in via `checkpoint-number`, will be used to create a path to the checkpoint we are going to evaluate. The created checkpoint path then will be passed into RLlib rollout script for evaluation.\n", | 
|  | 574 | +        "\n", | 
| 538 | 575 |         "Now let's configure rollout estimator. Note that we use the last checkpoint for evaluation. The assumption is that the last checkpoint points to our best trained agent. You may change this to any of the checkpoint numbers printed above and observe the effect." | 
| 539 | 576 |       ] | 
| 540 | 577 |     }, | 
|  | 
| 576 | 613 |         "    \n", | 
| 577 | 614 |         "    # Data inputs\n", | 
| 578 | 615 |         "    inputs=[\n", | 
| 579 |  | -        "        training_artifacts_ds.as_named_input('artifacts_dataset'),\n", | 
| 580 |  | -        "        training_artifacts_ds.as_named_input('artifacts_path').as_mount()],\n", | 
|  | 616 | +        "        checkpoint_ds.as_named_input('artifacts_dataset'),\n", | 
|  | 617 | +        "        checkpoint_ds.as_named_input('artifacts_path').as_mount()],\n", | 
| 581 | 618 |         "    \n", | 
| 582 | 619 |         "    # The Azure Machine Learning compute target\n", | 
| 583 | 620 |         "    compute_target=compute_target,\n", | 
|  | 
0 commit comments