Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7f9c195
feat: add boilerplate and inner function call
jack-champagne Sep 25, 2024
7f35b40
add callbacks tests
jack-champagne Sep 25, 2024
f25a057
add tmp callback changes
jack-champagne Nov 11, 2024
1a0e8fb
Merge branch 'main' into feature/ipopt-callbacks
jack-champagne Nov 11, 2024
f88cde8
add docstrings
jack-champagne Nov 11, 2024
81cbad7
add generic callback factory
jack-champagne Nov 11, 2024
6b4a6cc
Merge commit '18d52526351f71093d7bca9b111779ac00e77bb4' into feature/…
andgoldschmidt Nov 11, 2024
6428fa6
add get best iterate to callbacks lib
jack-champagne Nov 12, 2024
70d5fa1
fix renamed callback
jack-champagne Nov 12, 2024
972aea9
move callback tests, rename callbacks
andgoldschmidt Nov 12, 2024
587c051
Merge branch 'feature/ipopt-callbacks' of github.com:kestrelquantum/Q…
andgoldschmidt Nov 12, 2024
bc406f3
local project toml changes
jack-champagne Nov 12, 2024
d267338
remove pulse animation pngs
jack-champagne Nov 12, 2024
ddb9f0e
add NamedTrajectories to docs compat
jack-champagne Nov 12, 2024
e5e64c6
fix typo NM version
jack-champagne Nov 12, 2024
1701f19
roll back callback fidelity
andgoldschmidt Nov 12, 2024
e8dabda
Merge branch 'feature/ipopt-callbacks' of github.com:kestrelquantum/Q…
andgoldschmidt Nov 12, 2024
0ecef44
rm manifest
jack-champagne Nov 12, 2024
1c4d641
test unitary fid callback
andgoldschmidt Nov 12, 2024
f712b40
fix doctests
jack-champagne Nov 12, 2024
e4f7cbc
fixup docs
jack-champagne Nov 12, 2024
a1586aa
callbacks version bump
jack-champagne Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions docs/literate/man/ipopt_callbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# # IpOpt Solver Callbacks

# This page describes the callback functions that can be used with the IpOpt solver (in the future, may describe more general callback behavior).

# ## Callbacks

# By default, IpOpt callbacks are called at each optimization step with the following signature:
using QuantumCollocation
using NamedTrajectories

import ..QuantumStateSmoothPulseProblem
import ..Callbacks

function get_history_callback(
alg_mod::Cint,
iter_count::Cint,
obj_value::Float64,
inf_pr::Float64,
inf_du::Float64,
mu::Float64,
d_norm::Float64,
regularization_size::Float64,
alpha_du::Float64,
alpha_pr::Float64,
ls_trials::Cint,
)
return true
end

# This gives the user access to some of the optimization state internals at each iteration.
# A callback function with any subset of these arguments can be passed into the `solve!` function via the `callback` keyword argument see below.
# ```@docs
# ProblemSolvers.solve!(prob::QuantumControlProblem; callback=nothing)
# ```

# The callback function can be used to monitor the optimization progress, save intermediate results, or modify the optimization process.
# For example, the following callback function saves the optimization trajectory at each iteration:

trajectory_history = []
function get_history_callback(
kwargs...
)
push!(trajectory_history, QuantumCollocation.Problems.get_datavec(prob))
return true
end

# The callback function can also be used to stop the optimization early by returning `false`. The following callback when passed to `solve!` will stop the optimization after the first iteration:
my_callback = (kwargs...) -> false

T = 50
Δt = 0.2
sys = QuantumSystem(0.1 * GATES[:Z], [GATES[:X], GATES[:Y]])
ψ_init = Vector{ComplexF64}([1.0, 0.0])
ψ_target = Vector{ComplexF64}([0.0, 1.0])

# Single initial and target states
# --------------------------------
prob = QuantumStateSmoothPulseProblem(
sys, ψ_init, ψ_target, T, Δt;
ipopt_options=IpoptOptions(print_level=1),
piccolo_options=PiccoloOptions(verbose=false)
)

trajectory_history = []
# using the get_history_callback function to save the optimization trajectory at each iteration into the trajectory_history
solve!(prob, max_iter=20, callback=get_history_callback)

for (iter, traj) in enumerate(trajectory_history)
# get the length of the trajectory history depending on length and left pad the index with leading zeros
str_index = lpad(iter, length(string(length(trajectory_history))), "0")
# plot the trajectory but on fixed xaxis and yaxis
plot("./iteration-$str_index-trajectory.png", NamedTrajectory(traj, prob.trajectory), [:ψ̃1, :a], xlims=(-Δt, (T+5)*Δt), plot_ylims=(ψ̃1 = (-2, 2), a = (-1.1, 1.1)))
end

# Using a callback to get the best trajectory from all the optimization iterations
T = 50
Δt = 0.2
sys = QuantumSystem(0.1 * GATES[:Z], [GATES[:X], GATES[:Y]])
ψ_init = Vector{ComplexF64}([0.0, 1.0])
ψ_target = Vector{ComplexF64}([1.0, 0.0])

# Single initial and target states
# --------------------------------
prob = QuantumStateSmoothPulseProblem(
sys, ψ_init, ψ_target, T, Δt;
ipopt_options=IpoptOptions(print_level=1),
piccolo_options=PiccoloOptions(verbose=false)
)

(best_traj_callback, best_traj) = make_save_best_trajectory_callback(prob, prob.trajectory)
solve!(prob, max_iter=20, callback=best_traj_callback)
# fidelity of the last iterate
@show fidelity(prob)
# fidelity of the best iterate
@show fidelity(best_traj)
Binary file added docs/src/assets/animation.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/src/assets/pulse-optimization.gif
3 changes: 3 additions & 0 deletions src/QuantumCollocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,8 @@ include("problem_solvers.jl")
include("plotting.jl")
@reexport using .Plotting

include("callbacks.jl")
@reexport using .Callbacks


end
26 changes: 26 additions & 0 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module Callbacks

export make_save_best_trajectory_callback

using NamedTrajectories

import ..QuantumControlProblem
import ..Problems: get_datavec

function make_save_best_trajectory_callback(prob::QuantumControlProblem, traj::NamedTrajectory)
best_traj = NamedTrajectory(get_datavec(prob), prob.trajectory)
best_fidelity = 0.0

function save_best_trajectory_callback(prob::QuantumControlProblem)
fidelity = prob.objective.fidelity(prob)
if fidelity > best_fidelity
best_fidelity = fidelity
copy!(best_traj, prob.trajectory)
end
return true
end

return save_best_trajectory_callback, best_traj
end

end
39 changes: 39 additions & 0 deletions src/problem_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,42 @@ using ..Options

using NamedTrajectories
using MathOptInterface
using Ipopt
const MOI = MathOptInterface


@doc raw"""
solve!(prob::QuantumControlProblem;
init_traj=nothing,
save_path=nothing,
max_iter=prob.ipopt_options.max_iter,
linear_solver=prob.ipopt_options.linear_solver,
print_level=prob.ipopt_options.print_level,
remove_slack_variables=false,
callback=nothing
# state_type=:unitary,
# print_fidelity=false,
)

Call optimization solver to solve the quantum control problem with parameters and callbacks.

prob: QuantumControlProblem
The quantum control problem to solve.
init_traj: NamedTrajectory, optional
Initial guess for the control trajectory. If not provided, a random guess will be generated.
save_path: String, optional
Path to save the problem after optimization.
max_iter: Int, optional
Maximum number of iterations for the optimization solver.
linear_solver: String, optional (e.g., "mumps", "paradiso", etc)
Linear solver to use for the optimization solver.
print_level: Int, optional
Verbosity level for the solver
remove_slack_variables: Bool, optional
Remove slack variables from the trajectory after optimization.
callback: Function, optional
Callback function to call during optimization steps.
"""
function solve!(
prob::QuantumControlProblem;
init_traj=nothing,
Expand All @@ -19,6 +53,7 @@ function solve!(
linear_solver::String=prob.ipopt_options.linear_solver,
print_level::Int=prob.ipopt_options.print_level,
remove_slack_variables::Bool=false,
callback=nothing
# state_type::Symbol=:unitary,
# print_fidelity::Bool=false,
)
Expand All @@ -35,6 +70,10 @@ function solve!(
else
set_trajectory!(prob)
end

if !isnothing(callback)
MOI.set(prob.optimizer, Ipopt.CallbackFunction(), callback)
end

# if print_fidelity
# if state_type == :ket
Expand Down
119 changes: 119 additions & 0 deletions test/callbacks_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Testing callback features

- callback gives early stopping
- callback gives trajectory data
- callback example test with full argument list
"""

@testitem "Callback returns false early stops" begin
using MathOptInterface
const MOI = MathOptInterface

T = 50
Δt = 0.2
sys = QuantumSystem(0.1 * GATES[:Z], [GATES[:X], GATES[:Y]])
ψ_init = [1.0, 0.0]
ψ_target = [0.0, 1.0]

# Single initial and target states
# --------------------------------
prob = QuantumStateSmoothPulseProblem(
sys, ψ_init, ψ_target, T, Δt;
ipopt_options=IpoptOptions(print_level=1),
piccolo_options=PiccoloOptions(verbose=false)
)

my_callback = (kwargs...) -> false

initial = fidelity(prob)
solve!(prob, max_iter=20, callback=my_callback)
final = fidelity(prob)

# callback forces problem to exit early as per Ipopt documentation
@test MOI.get(prob.optimizer, MOI.TerminationStatus()) == MOI.INTERRUPTED
@test initial ≈ final atol=1e-2
end


@testitem "Callback can get internal history" begin
using MathOptInterface
using NamedTrajectories
const MOI = MathOptInterface

T = 50
Δt = 0.2
sys = QuantumSystem(0.1 * GATES[:Z], [GATES[:X], GATES[:Y]])
ψ_init = [1.0, 0.0]
ψ_target = [0.0, 1.0]

# Single initial and target states
# --------------------------------
prob = QuantumStateSmoothPulseProblem(
sys, ψ_init, ψ_target, T, Δt;
ipopt_options=IpoptOptions(print_level=1),
piccolo_options=PiccoloOptions(verbose=false)
)

trajectory_history = []
function get_history_callback(
kwargs...
)
push!(trajectory_history, QuantumCollocation.Problems.get_datavec(prob))
return true
end

initial = fidelity(prob)
solve!(prob, max_iter=20, callback=get_history_callback)

# for (iter, traj) in enumerate(trajectory_history)
# # change next line to plot the trajectory but on fixed xaxis and yaxis
# # get the length of the trajectory history depending on length and left pad the index with leading zeros
# str_index = lpad(iter, length(string(length(trajectory_history))), "0")
# plot("./test_animation_frames/iteration-$str_index-trajectory.png", NamedTrajectory(traj, prob.trajectory), [:ψ̃1, :a], xlims=(-Δt, (T+5)*Δt), plot_ylims=(ψ̃1 = (-2, 2), a = (-1.1, 1.1)))
# end
@test length(trajectory_history) == 21
end

@testitem "Callback with full parameter test" begin
using MathOptInterface
using NamedTrajectories
const MOI = MathOptInterface

T = 50
Δt = 0.2
sys = QuantumSystem(0.1 * GATES[:Z], [GATES[:X], GATES[:Y]])
ψ_init = [1.0, 0.0]
ψ_target = [0.0, 1.0]

# Single initial and target states
# --------------------------------
prob = QuantumStateSmoothPulseProblem(
sys, ψ_init, ψ_target, T, Δt;
ipopt_options=IpoptOptions(print_level=1),
piccolo_options=PiccoloOptions(verbose=false)
)

obj_vals = []
function get_history_callback(
alg_mod::Cint,
iter_count::Cint,
obj_value::Float64,
inf_pr::Float64,
inf_du::Float64,
mu::Float64,
d_norm::Float64,
regularization_size::Float64,
alpha_du::Float64,
alpha_pr::Float64,
ls_trials::Cint,
)
push!(obj_vals, obj_value)
return iter_count < 3
end

solve!(prob, max_iter=20, callback=get_history_callback)

@test MOI.get(prob.optimizer, MOI.TerminationStatus()) == MOI.INTERRUPTED
@test length(obj_vals) == 4 # problem init, iter 1, iter 2, iter 3 (terminate)
end
Loading