Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding support for recovering rollout fidelities as dictionary post-s…
…olve
  • Loading branch information
gennadiryan committed Oct 24, 2025
commit 5154f4f82955c46a504608b1c5f49fd3333c8336
9 changes: 7 additions & 2 deletions src/solvers/ipopt_solver/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ function callback_update_optimizer_state_history_factory(states::Vector{<:NamedT
end

"""
function callback_rollout_fidelity_factory(problem::DirectTrajOptProblem, system::Any, fid_fn::Function; fid_thresh=nothing, freq=1)
function callback_rollout_fidelity_factory(problem::DirectTrajOptProblem, system::Any, fid_fn::Function, fidelities::Union{Nothing, Dict{Int32, <:Real}}=nothing; fid_thresh=nothing, freq=1)

A callback factory returning a callback that computes the rollout fidelity associated with an intermediate trajectory via `fid_fn(problem.trajectory, system)`, once every `freq` iterations, and stops the solver in its tracks if `!(fid_thresh isa Nothing) && fid >= fid_thresh`.
This is particularly useful for the early stages of a solve, when dynamics constraints are yet to be satisfied, during which time changes in the objective are a poor proxy for the true infidelity of the system at its final timestep.
Expand All @@ -166,7 +166,8 @@ This is particularly useful for the early stages of a solve, when dynamics const
- This callback expects that it be called after `_callback_update_trajectory`
- This callback is meant to be used with `QuantumCollocation`, though it is not strictly necessary; a custom rollout method may be used in place of e.g. `QuantumCollocation.unitary_rollout_fidelity`, as long as it has the correct type signature
"""
function callback_rollout_fidelity_factory(problem::DirectTrajOptProblem, system::Any, fid_fn::Function; fid_thresh=nothing, freq=1)

function callback_rollout_fidelity_factory(problem::DirectTrajOptProblem, system::Any, fid_fn::Function, fidelities::Union{Nothing, Dict{Int32, <:Real}}=nothing; fid_thresh=nothing, freq=1)
function _callback_rollout_fidelity(optimizer::Ipopt.Optimizer, optimizer_state::IpoptOptimizerState; kwargs...)
if optimizer_state.iter_count % freq != 0
return true
Expand All @@ -178,6 +179,10 @@ function callback_rollout_fidelity_factory(problem::DirectTrajOptProblem, system
println()
println("Fidelity: ", fid)

if !(fidelities isa Nothing)
push!(fidelities, Pair(optimizer_state.iter_count, fid))
end

return fid_thresh isa Nothing || fid < fid_thresh
end
end
Expand Down