Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added optimizer state tracking callback; added demo method for future…
… incorporation into unit tests
  • Loading branch information
gennadiryan committed Oct 24, 2025
commit b9f73fadc54b79b00a6b98ccea9f595972030137
37 changes: 31 additions & 6 deletions src/solvers/ipopt_solver/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ function callback_update_trajectory_history_factory(problem::DirectTrajOptProble
end
end

"""
function callback_update_optimizer_state_history_factory(states::Vector{<:NamedTuple})

A callback factory returning a callback that populates `states` with a `deepcopy` of the `NamedTuple` containing the optimizer state at each iteration.
Useful for debugging.
"""

function callback_update_optimizer_state_history_factory(states::Vector{<:NamedTuple})
function _callback_update_optimizer_state_history(optimizer::Ipopt.Optimizer, optimizer_state::IpoptOptimizerState; kwargs...)
push!(states, deepcopy(optimizer_state))
return true
end
end

"""
function callback_rollout_fidelity_factory(problem::DirectTrajOptProblem, system::Any, fid_fn::Function; fid_thresh=nothing, freq=1)

Expand Down Expand Up @@ -222,7 +236,7 @@ function callback_best_rollout_fidelity_factory(problem::DirectTrajOptProblem, s
end
end

function cb_test()
function test_update_trajectory(update_trajectory::Bool)
include("test/test_utils.jl")

G, traj = bilinear_dynamics_and_trajectory()
Expand All @@ -242,18 +256,29 @@ function cb_test()

prob = DirectTrajOptProblem(traj, J, integrators; constraints=AbstractConstraint[g_u_norm])

callback = callback_factory(
callbacks = [
callback_say_hello_factory("Hello, world!"),
callback_stop_iteration_factory(50),
)
]
if update_trajectory
pushfirst!(callbacks, callback_update_trajectory_factory(prob))
end
callback = callback_factory(callbacks...)

traj_init = deepcopy(prob.traj)
trajs = NamedTrajectory[]

optimizer, variables = IpoptSolverExt.get_optimizer_and_variables(prob, IpoptOptions(; max_iter=100), callback)
IpoptSolverExt.MOI.optimize!(optimizer)

traj_final = prob.trajectory

optimizer, variables = get_optimizer_and_variables(prob, IpoptOptions(; max_iter=100), callback)
MOI.optimize!(optimizer)
@assert (traj_final == traj_init) == !update_trajectory

# solve!(prob; max_iter=100)
end

@testitem begin cb_test() end
# @testitem begin cb_test() end


end