-
Notifications
You must be signed in to change notification settings - Fork 4
Adding Ipopt callbacks to DTO (previously implemented in QC) #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
60ae85e
Experimenting with two variations on callback generating factory pattern
gennadiryan 7f01474
Adding some examples (not yet tests)
gennadiryan 553ed17
Minor cleanup
gennadiryan 37dd964
Typo fix
gennadiryan 3a73727
Adding new callbacks
gennadiryan 8c12548
Latest commit
gennadiryan e6639fe
Adding more descriptive docstrings; considering adding the bulk of th…
gennadiryan 4362cf5
Commenting out examples
gennadiryan fef84cf
Added a brief note re trajectory updates mid-solve vs at end of solve
gennadiryan 200e95f
Trying out callback test
gennadiryan b9f73fa
Added optimizer state tracking callback; added demo method for future…
gennadiryan 677e682
Tests are broken
gennadiryan 29987de
Fixed test implementation
gennadiryan 5154f4f
Adding support for recovering rollout fidelities as dictionary post-s…
gennadiryan 46cc295
Verbump to 0.4.2
gennadiryan 829e19f
Merge remote-tracking branch 'origin/main' into feature/ipopt-callbacks
gennadiryan 2c6ff36
Verbump to v0.4.3
gennadiryan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| module Callbacks | ||
|
|
||
| using ..DirectTrajOpt | ||
| using Ipopt | ||
|
|
||
| """ | ||
| # Callbacks evaluated by Ipopt should have the following signature: | ||
| # Note that Cint === Int32 && Cdouble === Float64 | ||
|
|
||
| function my_intermediate_callback( | ||
| alg_mod::Cint, | ||
| iter_count::Cint, | ||
| obj_value::Float64, | ||
| inf_pr::Float64, | ||
| inf_du::Float64, | ||
| mu::Float64, | ||
| d_norm::Float64, | ||
| regularization_size::Float64, | ||
| alpha_du::Float64, | ||
| alpha_pr::Float64, | ||
| ls_trials::Cint, | ||
| ) | ||
| # ... user code ... | ||
| return true # or `return false` to terminate the solve. | ||
| end | ||
| """ | ||
|
|
||
| """ | ||
| Take 1 | ||
| """ | ||
|
|
||
| function callback_update_trajectory(problem::DirectTrajOptProblem; callback=nothing) | ||
| function __callback(optimizer::Ipopt.Optimizer) | ||
| function _callback(args...) | ||
| IpoptSolverExt.update_trajectory!(prob, optimizer, optimizer.list_of_variable_indices) | ||
| if callback isa Nothing | ||
| return true | ||
| end | ||
| # by now, the trajectory is up to date, so `callback` can make use of it for e.g. rollouts | ||
| return callback(args...) | ||
| end | ||
| return _callback | ||
| end | ||
| return __callback | ||
| end | ||
|
|
||
| function callback_update_trajectory_with_rollout(problem::DirectTrajOptProblem, fid_fn::Function; callback=nothing, fid_thresh=0.99, freq=1) | ||
| function __callback(optimizer::Ipopt.Optimizer) | ||
| function _callback(args...) | ||
| IpoptSolverExt.update_trajectory!(prob, optimizer, optimizer.list_of_variable_indices) | ||
|
|
||
| res = (callback isa Nothing) || (callback(args...)) | ||
|
|
||
| # we should evaluate `fid_fn` every `freq` iterations even if !res | ||
| if args[2] % freq == 0 | ||
| res_fid = fid_fn(prob.trajectory) < fid_thresh | ||
| return res && res_fid | ||
| end | ||
|
|
||
| return res | ||
| end | ||
| return _callback | ||
| end | ||
| return __callback | ||
| end | ||
|
|
||
| """ | ||
| Take 2 | ||
| """ | ||
|
|
||
| """ | ||
| # Example usage: | ||
| # | ||
| # # The following solve should proceed as usual, printing the current fidelity (as computed by unitary_rollout_fidelity) once every 10 iterations, and stopping once it exceeds 0.999 | ||
| # > initial = unitary_rollout_fidelity(prob.trajectory, sys) | ||
| # > cb = callback_factory(_callback_update_trajectory_factory(prob), _callback_rollout_fidelity_factory(prob, sys, unitary_rollout_fidelity; fid_thresh=0.999, freq=10)) | ||
| # > solve!(prob; max_iter=100, callback=cb) | ||
| # > final = unitary_rollout_fidelity(prob.trajectory, sys) | ||
| # > @assert final > initial | ||
| # | ||
| # # Terminating the solve manually (via Ctrl+C) will result in the final fidelity matching the initial fidelity (loss of solver progress) if _callback_update_trajectory is omitted | ||
| # > do_traj_update = false | ||
| # > initial = unitary_rollout_fidelity(prob.trajectory, sys) | ||
| # > cb = callback_factory((do_traj_update ? [_callback_update_trajectory_factory(prob)] : [])...) | ||
| # > solve!(prob; max_iter=100, callback=cb) | ||
| # > final = unitary_rollout_fidelity(prob.trajectory, sys) | ||
| # > @assert (final == initial) == (!do_traj_update) | ||
| """ | ||
| function callback_factory(callbacks...; kwargs...) | ||
| function _callback_factory(optimizer::Ipopt.Optimizer) | ||
| function _callback(optimizer_state...) | ||
| for callback in callbacks | ||
| res = callback(optimizer, optimizer_state; kwargs) | ||
| if !res | ||
| return false | ||
| end | ||
| end | ||
| return true | ||
| end | ||
| end | ||
| end | ||
|
|
||
| function _callback_say_hello_factory(msg) | ||
| function _callback_say_hello(optimizer, optimizer_state; kwargs...) | ||
| println(msg) | ||
| return true | ||
| end | ||
| end | ||
|
|
||
| function _callback_stop_iteration_factory(stop_iteration) | ||
| function _callback_stop_iteration(optimizer, optimizer_state; kwargs...) | ||
| if optimizer_state[2] >= stop_iteration | ||
gennadiryan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return false | ||
| end | ||
| return true | ||
| end | ||
| end | ||
|
|
||
| function _callback_update_trajectory_factory(problem) | ||
| function _callback_update_trajectory(optimizer, optimizer_state; kwargs...) | ||
| IpoptSolverExt.update_trajectory!(problem, optimizer, optimizer.list_of_variable_indices) | ||
| return true | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| # WARNING: This callback expects that _callback_update_trajectory was evaluated beforehand | ||
| # However, a custom callback can just as well do both in one go, especially if the overhead from doing a trajectory update once per iteration is undesirable | ||
| """ | ||
| function _callback_rollout_fidelity_factory(problem, system, fid_fn; fid_thresh=nothing, freq=1) | ||
| function _callback_rollout_fidelity(optimizer, optimizer_state; kwargs...) | ||
| if optimizer_state[2] % freq != 0 | ||
| return true | ||
| end | ||
|
|
||
| fid = fid_fn(problem.trajectory, system) | ||
|
|
||
| # Probably comment this out and/or customize display of fidelities | ||
| println() | ||
| println("Fidelity: ", fid) | ||
|
|
||
| return fid_thresh isa Nothing || fid < fid_thresh | ||
| end | ||
| end | ||
|
|
||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.