diff --git a/src/problem_templates/_problem_templates.jl b/src/problem_templates/_problem_templates.jl index abb41cf9..5cbd67b2 100644 --- a/src/problem_templates/_problem_templates.jl +++ b/src/problem_templates/_problem_templates.jl @@ -32,7 +32,7 @@ include("unitary_bang_bang_problem.jl") include("quantum_state_smooth_pulse_problem.jl") include("quantum_state_minimum_time_problem.jl") - +include("quantum_state_sampling_problem.jl") function apply_piccolo_options!( J::Objective, diff --git a/src/problem_templates/quantum_state_sampling_problem.jl b/src/problem_templates/quantum_state_sampling_problem.jl new file mode 100644 index 00000000..6d8e47d2 --- /dev/null +++ b/src/problem_templates/quantum_state_sampling_problem.jl @@ -0,0 +1,209 @@ +export QuantumStateSamplingProblem + +function QuantumStateSamplingProblem end + +function QuantumStateSamplingProblem( + systems::AbstractVector{<:AbstractQuantumSystem}, + ψ_inits::Vector{<:AbstractVector{<:ComplexF64}}, + ψ_goals::Vector{<:AbstractVector{<:ComplexF64}}, + T::Int, + Δt::Union{Float64,Vector{Float64}}; + system_weights=fill(1.0, length(systems)), + init_trajectory::Union{NamedTrajectory,Nothing}=nothing, + ipopt_options::IpoptOptions=IpoptOptions(), + piccolo_options::PiccoloOptions=PiccoloOptions(), + state_name::Symbol=:ψ̃, + control_name::Symbol=:a, + timestep_name::Symbol=:Δt, + a_bound::Float64=1.0, + a_bounds=fill(a_bound, length(systems[1].G_drives)), + a_guess::Union{Matrix{Float64},Nothing}=nothing, + da_bound::Float64=Inf, + da_bounds::Vector{Float64}=fill(da_bound, length(systems[1].G_drives)), + dda_bound::Float64=1.0, + dda_bounds::Vector{Float64}=fill(dda_bound, length(systems[1].G_drives)), + Δt_min::Float64=0.5 * Δt, + Δt_max::Float64=1.5 * Δt, + drive_derivative_σ::Float64=0.01, + Q::Float64=100.0, + R=1e-2, + R_a::Union{Float64,Vector{Float64}}=R, + R_da::Union{Float64,Vector{Float64}}=R, + R_dda::Union{Float64,Vector{Float64}}=R, + leakage_operator::Union{Nothing, EmbeddedOperator}=nothing, + constraints::Vector{<:AbstractConstraint}=AbstractConstraint[], + kwargs... +) + @assert length(ψ_inits) == length(ψ_goals) + + # Trajectory + if !isnothing(init_trajectory) + traj = init_trajectory + else + n_drives = length(systems[1].G_drives) + + traj = initialize_trajectory( + ψ_goals, + ψ_inits, + T, + Δt, + n_drives, + (a_bounds, da_bounds, dda_bounds); + state_name=state_name, + control_name=control_name, + timestep_name=timestep_name, + free_time=piccolo_options.free_time, + Δt_bounds=(Δt_min, Δt_max), + bound_state=piccolo_options.bound_state, + drive_derivative_σ=drive_derivative_σ, + a_guess=a_guess, + system=systems, + rollout_integrator=piccolo_options.rollout_integrator, + ) + end + + # Outer dimension is the system, inner dimension is the initial state + state_names = [ + name for name ∈ traj.names + if startswith(string(name), string(state_name)) + ] + @assert length(ψ_inits) * length(systems) == length(state_names) "State names do not match number of systems and initial states" + state_names = reshape(state_names, length(ψ_inits), length(systems)) + + control_names = [ + name for name ∈ traj.names + if endswith(string(name), string(control_name)) + ] + + # Objective + J = QuadraticRegularizer(control_names[1], traj, R_a; timestep_name=timestep_name) + J += QuadraticRegularizer(control_names[2], traj, R_da; timestep_name=timestep_name) + J += QuadraticRegularizer(control_names[3], traj, R_dda; timestep_name=timestep_name) + + for (weight, names) in zip(system_weights, eachcol(state_names)) + for name in names + J += weight * QuantumStateObjective(name, traj, Q) + end + end + + # Integrators + state_integrators = [] + for (system, names) ∈ zip(systems, eachcol(state_names)) + for name ∈ names + if piccolo_options.integrator == :pade + state_integrator = QuantumStatePadeIntegrator( + system, name, control_name, traj; + order=piccolo_options.pade_order + ) + elseif piccolo_options.integrator == :exponential + state_integrator = QuantumStateExponentialIntegrator( + system, name, control_name, traj + ) + else + error("integrator must be one of (:pade, :exponential)") + end + push!(state_integrators, state_integrator) + end + end + + integrators = [ + state_integrators..., + DerivativeIntegrator(control_name, control_names[2], traj), + DerivativeIntegrator(control_names[2], control_names[3], traj), + ] + + # Optional Piccolo constraints and objectives + apply_piccolo_options!( + J, constraints, piccolo_options, traj, leakage_operator, state_name, timestep_name + ) + + return QuantumControlProblem( + direct_sum(systems), + traj, + J, + integrators; + constraints=constraints, + ipopt_options=ipopt_options, + piccolo_options=piccolo_options, + kwargs... + ) +end + +function QuantumStateSamplingProblem( + systems::AbstractVector{<:AbstractQuantumSystem}, + ψ_init::AbstractVector{<:ComplexF64}, + ψ_goal::AbstractVector{<:ComplexF64}, + args...; + kwargs... +) + return QuantumStateSamplingProblem(systems, [ψ_init], [ψ_goal], args...; kwargs...) +end + + +# ============================================================================= + +@testitem "Sample systems with single initial, target" begin + # System + T = 50 + Δt = 0.2 + sys1 = QuantumSystem(0.3 * GATES[:Z], [GATES[:X], GATES[:Y]]) + sys2 = QuantumSystem(0.0 * GATES[:Z], [GATES[:X], GATES[:Y]]) + + # Single initial and target states + # -------------------------------- + ψ_init = Vector{ComplexF64}([1.0, 0.0]) + ψ_target = Vector{ComplexF64}([0.0, 1.0]) + + prob = QuantumStateSamplingProblem( + [sys1, sys2], ψ_init, ψ_target, T, Δt; + ipopt_options=IpoptOptions(print_level=1), + piccolo_options=PiccoloOptions(verbose=false) + ) + + state_names = [n for n ∈ prob.trajectory.names if startswith(string(n), "ψ̃")] + + init = [fidelity(prob.trajectory, sys1, state_symb=n) for n in state_names] + solve!(prob, max_iter=20) + final = [fidelity(prob.trajectory, sys1, state_symb=n) for n in state_names] + @test all(final .> init) + + # Compare a solution without robustness + # ------------------------------------- + prob_default = QuantumStateSmoothPulseProblem( + sys2, ψ_init, ψ_target, T, Δt; + ipopt_options=IpoptOptions(print_level=1), + piccolo_options=PiccoloOptions(verbose=false), + robustness=false + ) + solve!(prob_default, max_iter=20) + final_default = fidelity(prob_default.trajectory, sys1) + final_robust = fidelity(prob.trajectory, sys1, state_symb=state_names[1]) + @test final_robust > final_default +end + +@testitem "Sample systems with multiple initial, target" begin + # System + T = 50 + Δt = 0.2 + sys1 = QuantumSystem(0.3 * GATES[:Z], [GATES[:X], GATES[:Y]]) + sys2 = QuantumSystem(0.0 * GATES[:Z], [GATES[:X], GATES[:Y]]) + + # Multiple initial and target states + # ---------------------------------- + ψ_inits = Vector{ComplexF64}.([[1.0, 0.0], [0.0, 1.0]]) + ψ_targets = Vector{ComplexF64}.([[0.0, 1.0], [1.0, 0.0]]) + + prob = QuantumStateSamplingProblem( + [sys1, sys2], ψ_inits, ψ_targets, T, Δt; + ipopt_options=IpoptOptions(print_level=1), + piccolo_options=PiccoloOptions(verbose=false) + ) + + state_names = [n for n ∈ prob.trajectory.names if startswith(string(n), "ψ̃")] + + init = [fidelity(prob.trajectory, sys1, state_symb=n) for n in state_names] + solve!(prob, max_iter=20) + final = [fidelity(prob.trajectory, sys1, state_symb=n) for n in state_names] + @test all(final .> init) +end + diff --git a/src/problem_templates/quantum_state_smooth_pulse_problem.jl b/src/problem_templates/quantum_state_smooth_pulse_problem.jl index 70623d96..21b8cb7f 100644 --- a/src/problem_templates/quantum_state_smooth_pulse_problem.jl +++ b/src/problem_templates/quantum_state_smooth_pulse_problem.jl @@ -104,6 +104,7 @@ function QuantumStateSmoothPulseProblem( timestep_name=timestep_name, free_time=piccolo_options.free_time, Δt_bounds=(Δt_min, Δt_max), + bound_state=piccolo_options.bound_state, drive_derivative_σ=drive_derivative_σ, a_guess=a_guess, system=system, @@ -111,76 +112,42 @@ function QuantumStateSmoothPulseProblem( ) end - # Objective + state_names = [ + name for name ∈ traj.names + if startswith(string(name), string(state_name)) + ] + @assert length(state_names) == length(ψ_inits) "Number of states must match number of initial states" + control_names = [ name for name ∈ traj.names if endswith(string(name), string(control_name)) ] + # Objective J = QuadraticRegularizer(control_names[1], traj, R_a; timestep_name=timestep_name) J += QuadraticRegularizer(control_names[2], traj, R_da; timestep_name=timestep_name) J += QuadraticRegularizer(control_names[3], traj, R_dda; timestep_name=timestep_name) - if length(ψ_inits) == 1 - J += QuantumStateObjective(state_name, traj, Q) - else - state_names = [ - name for name ∈ traj.names - if startswith(string(name), string(state_name)) - ] - @assert length(state_names) == length(ψ_inits) "Number of states must match number of initial states" - for i = 1:length(ψ_inits) - J += QuantumStateObjective(state_names[i], traj, Q) - end + for name ∈ state_names + J += QuantumStateObjective(name, traj, Q) end # Integrators - if length(ψ_inits) == 1 + state_integrators = [] + for name ∈ state_names if piccolo_options.integrator == :pade - state_integrators = [QuantumStatePadeIntegrator( - system, - state_name, - control_name, - traj; + state_integrator = QuantumStatePadeIntegrator( + system, name, control_name, traj; order=piccolo_options.pade_order - )] + ) elseif piccolo_options.integrator == :exponential - state_integrators = [QuantumStateExponentialIntegrator( - system, - state_name, - control_name, - traj - )] + state_integrator = QuantumStateExponentialIntegrator( + system, name, control_name, traj + ) else error("integrator must be one of (:pade, :exponential)") end - else - state_names = [ - name for name ∈ traj.names - if startswith(string(name), string(state_name)) - ] - state_integrators = [] - for i = 1:length(ψ_inits) - if piccolo_options.integrator == :pade - state_integrator = QuantumStatePadeIntegrator( - system, - state_names[i], - control_name, - traj; - order=piccolo_options.pade_order - ) - elseif piccolo_options.integrator == :exponential - state_integrator = QuantumStateExponentialIntegrator( - system, - state_names[i], - control_name, - traj - ) - else - error("integrator must be one of (:pade, :exponential)") - end - push!(state_integrators, state_integrator) - end + push!(state_integrators, state_integrator) end integrators = [ @@ -191,13 +158,7 @@ function QuantumStateSmoothPulseProblem( # Optional Piccolo constraints and objectives apply_piccolo_options!( - J, - constraints, - piccolo_options, - traj, - leakage_operator, - state_name, - timestep_name + J, constraints, piccolo_options, traj, leakage_operator, state_name, timestep_name ) return QuantumControlProblem( diff --git a/src/problem_templates/unitary_sampling_problem.jl b/src/problem_templates/unitary_sampling_problem.jl index 6d32fd20..5aead037 100644 --- a/src/problem_templates/unitary_sampling_problem.jl +++ b/src/problem_templates/unitary_sampling_problem.jl @@ -47,7 +47,6 @@ function UnitarySamplingProblem( operator::OperatorType, T::Int, Δt::Union{Float64,Vector{Float64}}; - system_labels=string.(1:length(systems)), system_weights=fill(1.0, length(systems)), init_trajectory::Union{NamedTrajectory,Nothing}=nothing, ipopt_options::IpoptOptions=IpoptOptions(), @@ -73,9 +72,6 @@ function UnitarySamplingProblem( R_dda::Union{Float64,Vector{Float64}}=R, kwargs... ) - # Create keys for multiple systems - Ũ⃗_names = [add_suffix(state_name, ℓ) for ℓ ∈ system_labels] - # Trajectory if !isnothing(init_trajectory) traj = init_trajectory @@ -99,71 +95,43 @@ function UnitarySamplingProblem( a_guess=a_guess, system=systems, rollout_integrator=piccolo_options.rollout_integrator, - state_names=Ũ⃗_names ) end + state_names = [ + name for name ∈ traj.names + if startswith(string(name), string(state_name)) + ] + control_names = [ name for name ∈ traj.names if endswith(string(name), string(control_name)) ] # Objective - J = NullObjective() - for (wᵢ, Ũ⃗_name) in zip(system_weights, Ũ⃗_names) - J += wᵢ * UnitaryInfidelityObjective( - Ũ⃗_name, traj, Q; - subspace=operator isa EmbeddedOperator ? operator.subspace_indices : nothing - ) - end - J += QuadraticRegularizer(control_names[1], traj, R_a; timestep_name=timestep_name) + J = QuadraticRegularizer(control_names[1], traj, R_a; timestep_name=timestep_name) J += QuadraticRegularizer(control_names[2], traj, R_da; timestep_name=timestep_name) J += QuadraticRegularizer(control_names[3], traj, R_dda; timestep_name=timestep_name) - # Constraints - if piccolo_options.leakage_suppression - if operator isa EmbeddedOperator - leakage_indices = get_iso_vec_leakage_indices(operator) - for Ũ⃗_name in Ũ⃗_names - J += L1Regularizer!( - constraints, Ũ⃗_name, traj, - R_value=piccolo_options.R_leakage, - indices=leakage_indices, - eval_hessian=piccolo_options.eval_hessian - ) - end - else - @warn "leakage_suppression is not supported for non-embedded operators, ignoring." - end - end - - if piccolo_options.free_time - if piccolo_options.timesteps_all_equal - push!(constraints, TimeStepsAllEqualConstraint(:Δt, traj)) - end - end - - if !isnothing(piccolo_options.complex_control_norm_constraint_name) - norm_con = ComplexModulusContraint( - piccolo_options.complex_control_norm_constraint_name, - piccolo_options.complex_control_norm_constraint_radius, - traj; + for (weight, name) in zip(system_weights, state_names) + J += weight * UnitaryInfidelityObjective( + name, traj, Q; + subspace=operator isa EmbeddedOperator ? operator.subspace_indices : nothing ) - push!(constraints, norm_con) end # Integrators unitary_integrators = AbstractIntegrator[] - for (sys, Ũ⃗_name) in zip(systems, Ũ⃗_names) + for (sys, name) in zip(systems, state_names) if piccolo_options.integrator == :pade push!( unitary_integrators, - UnitaryPadeIntegrator(sys, Ũ⃗_name, control_name, traj; order=piccolo_options.pade_order) + UnitaryPadeIntegrator(sys, name, control_name, traj; order=piccolo_options.pade_order) ) elseif piccolo_options.integrator == :exponential push!( unitary_integrators, - UnitaryExponentialIntegrator(sys, Ũ⃗_name, control_name, traj) + UnitaryExponentialIntegrator(sys, name, control_name, traj) ) else error("integrator must be one of (:pade, :exponential)") @@ -176,6 +144,11 @@ function UnitarySamplingProblem( DerivativeIntegrator(control_names[2], control_names[3], traj), ] + # Optional Piccolo constraints and objectives + apply_piccolo_options!( + J, constraints, piccolo_options, traj, operator, state_name, timestep_name + ) + return QuantumControlProblem( direct_sum(systems), traj, @@ -262,7 +235,7 @@ end piccolo_options=pi_ops ) - @test g1_prob.trajectory.Ũ⃗1 ≈ g1_prob.trajectory.Ũ⃗2 + @test g1_prob.trajectory.Ũ⃗_system_1 ≈ g1_prob.trajectory.Ũ⃗_system_2 g2_prob = UnitarySamplingProblem( [systems(0), systems(0.05)], operator, T, Δt; @@ -270,5 +243,5 @@ end piccolo_options=pi_ops ) - @test ~(g2_prob.trajectory.Ũ⃗1 ≈ g2_prob.trajectory.Ũ⃗2) + @test ~(g2_prob.trajectory.Ũ⃗_system_1 ≈ g2_prob.trajectory.Ũ⃗_system_2) end diff --git a/src/problem_templates/unitary_smooth_pulse_problem.jl b/src/problem_templates/unitary_smooth_pulse_problem.jl index 10986998..2a7cfe0a 100644 --- a/src/problem_templates/unitary_smooth_pulse_problem.jl +++ b/src/problem_templates/unitary_smooth_pulse_problem.jl @@ -158,7 +158,8 @@ function UnitarySmoothPulseProblem( # Integrators if piccolo_options.integrator == :pade unitary_integrator = - UnitaryPadeIntegrator(system, state_name, control_name, traj; + UnitaryPadeIntegrator( + system, state_name, control_name, traj; order=piccolo_options.pade_order ) elseif piccolo_options.integrator == :exponential @@ -175,7 +176,9 @@ function UnitarySmoothPulseProblem( ] # Optional Piccolo constraints and objectives - apply_piccolo_options!(J, constraints, piccolo_options, traj, operator, state_name, timestep_name) + apply_piccolo_options!( + J, constraints, piccolo_options, traj, operator, state_name, timestep_name + ) return QuantumControlProblem( system, diff --git a/src/trajectory_initialization.jl b/src/trajectory_initialization.jl index f6e1bc6b..f78f0bb3 100644 --- a/src/trajectory_initialization.jl +++ b/src/trajectory_initialization.jl @@ -252,15 +252,15 @@ initialize_control_trajectory(a::AbstractMatrix, Δt::Real, n_derivatives::Int) # ----------------------------------------------------------------------------- # """ - initialize_unitary_trajectory + initialize_trajectory -Initialize a trajectory for a unitary control problem. The trajectory is initialized with +Initialize a trajectory for a control problem. The trajectory is initialized with data that should be consistently the same type (in this case, Float64). """ function initialize_trajectory( - state_data::Vector{Matrix{Float64}}, + state_data::Vector{<:AbstractMatrix{Float64}}, state_inits::Vector{<:AbstractVector{Float64}}, state_goals::Vector{<:AbstractVector{Float64}}, state_names::AbstractVector{Symbol}, @@ -277,11 +277,9 @@ function initialize_trajectory( drive_derivative_σ::Float64=0.1, a_guess::Union{AbstractMatrix{<:Float64}, Nothing}=nothing, global_data::Union{NamedTuple, Nothing}=nothing, + verbose=false, ) - @assert !isnothing(state_inits) "state_init must be provided" - @assert !isnothing(state_goals) "state_goal must be provided" @assert length(state_data) == length(state_names) == length(state_inits) == length(state_goals) "state_data, state_names, state_inits, and state_goals must have the same length" - @assert length(control_bounds) == n_control_derivatives + 1 "control_bounds must have $n_control_derivatives + 1 elements" # assert that state names are unique @@ -289,9 +287,11 @@ function initialize_trajectory( # Control data control_derivative_names = [ - Symbol("d"^i * string(control_name)) - for i = 1:n_control_derivatives + Symbol("d"^i * string(control_name)) for i = 1:n_control_derivatives ] + if verbose + println("control_derivative_names: $control_derivative_names") + end control_names = (control_name, control_derivative_names...) @@ -312,7 +312,6 @@ function initialize_trajectory( timestep = Δt end - # Constraints initial = (; (state_names .=> state_inits)..., @@ -374,18 +373,56 @@ function initialize_trajectory( ) end +""" + Initialize the unitary states for a single set of parameters. + +""" +function initialize_unitaries( + U_goal::OperatorType, + T::Int, + timesteps::AbstractVector{<:Real}; + U_init::AbstractMatrix{<:Number}=Matrix{ComplexF64}(I(size(U_goal, 1))), + a_guess::Union{AbstractMatrix{<:Float64}, Nothing}=nothing, + system::Union{AbstractQuantumSystem, Nothing}=nothing, + rollout_integrator::Function=expv, + geodesic=true, +)::Matrix{Float64} + Ũ⃗_init = operator_to_iso_vec(U_init) + + if U_goal isa EmbeddedOperator + Ũ⃗_goal = operator_to_iso_vec(U_goal.operator) + else + Ũ⃗_goal = operator_to_iso_vec(U_goal) + end + + # Construct state data + if isnothing(a_guess) + state_data = initialize_unitary_trajectory(U_init, U_goal, T; geodesic=geodesic) + else + @assert !isnothing(system) "System must be provided if a_guess is provided." + state_data = unitary_rollout(Ũ⃗_init, a_guess, timesteps, system; integrator=rollout_integrator) + end + + return state_data +end + +""" + initialize_trajectory + +Trajectory initialization of unitary states can broadcast over multiple systems. +""" function initialize_trajectory( U_goal::OperatorType, T::Int, Δt::Union{Real, AbstractVecOrMat{<:Real}}, args...; state_name::Symbol=:Ũ⃗, - state_names::AbstractVector{<:Symbol}=[state_name], U_init::AbstractMatrix{<:Number}=Matrix{ComplexF64}(I(size(U_goal, 1))), a_guess::Union{AbstractMatrix{<:Float64}, Nothing}=nothing, system::Union{AbstractQuantumSystem, AbstractVector{<:AbstractQuantumSystem}, Nothing}=nothing, rollout_integrator::Function=expv, geodesic=true, + verbose=false, kwargs... ) Ũ⃗_init = operator_to_iso_vec(U_init) @@ -396,50 +433,48 @@ function initialize_trajectory( Ũ⃗_goal = operator_to_iso_vec(U_goal) end - if isnothing(a_guess) - # No guess provided, initialize a geodesic and randomly sample controls - - Ũ⃗_traj = initialize_unitary_trajectory(U_init, U_goal, T; geodesic=geodesic) - - state_data = repeat([Ũ⃗_traj], length(state_names)) + # Construct timesteps + if Δt isa AbstractMatrix + timesteps = vec(Δt) + elseif Δt isa Float64 + timesteps = fill(Δt, T) else - if system isa AbstractQuantumSystem - @assert size(a_guess, 1) == length(system.H_drives) "a_guess must have the same number of drives as n_drives" - elseif system isa AbstractVector - @assert size(a_guess, 1) == length(system[1].H_drives) "a_guess must have the same number of drives as n_drives" - end - - @assert size(a_guess, 2) == T "a_guess must have the same number of timesteps as T" - @assert !isnothing(system) "system must be provided if a_guess is provided" + timesteps = Δt + end - if Δt isa AbstractMatrix - timesteps = vec(Δt) - elseif Δt isa Float64 - timesteps = fill(Δt, T) - else - timesteps = Δt + # Construct state data + if system isa AbstractVector + state_data = map(system) do sys + initialize_unitaries( + U_goal, T, timesteps; + U_init=U_init, + a_guess=a_guess, + system=sys, + rollout_integrator=rollout_integrator, + geodesic=geodesic + ) end - if system isa AbstractVector - @assert length(system) == length(state_names) "systems must have the same length as state_names" - state_data = map(system) do sys - unitary_rollout(Ũ⃗_init, a_guess, timesteps, sys; integrator=rollout_integrator) - end - else - state = unitary_rollout(Ũ⃗_init, a_guess, timesteps, system; integrator=rollout_integrator) - state_data = [state] + state_names = Symbol.([string(state_name) * "_system_$i" for i in eachindex(system)]) + if verbose + println("Created state names for ($(length(system))) systems: $state_names") end - state_data = Matrix{Float64}.(state_data) - end - - if system isa AbstractVector && length(state_names) != length(system) - state_names = [string(state_name) * "_system_$i" for i = 1:length(system)] - @warn "length of state_names and number of systems ($(length(system))) are not equal, created state names for each system: " state_names + state_inits = repeat([Ũ⃗_init], length(system)) + state_goals = repeat([Ũ⃗_goal], length(system)) + else + state_data = [initialize_unitaries( + U_goal, T, timesteps; + U_init=U_init, + a_guess=a_guess, + system=system, + rollout_integrator=rollout_integrator, + geodesic=geodesic + )] + state_names = [state_name] + state_inits = [Ũ⃗_init] + state_goals = [Ũ⃗_goal] end - state_inits = repeat([Ũ⃗_init], length(state_names)) - state_goals = repeat([Ũ⃗_goal], length(state_names)) - return initialize_trajectory( state_data, state_inits, @@ -449,10 +484,51 @@ function initialize_trajectory( Δt, args...; a_guess=a_guess, + verbose=verbose, kwargs... ) end +""" + initialize_quantum_states + +Initialize the quantum states for a single set of parameters. +""" +function initialize_quantum_states( + ψ̃_goals::AbstractVector{<:AbstractVector{Float64}}, + ψ̃_inits::AbstractVector{<:AbstractVector{Float64}}, + T::Int, + timesteps::AbstractVector{<:Real}; + a_guess::Union{AbstractMatrix{<:Float64}, Nothing}=nothing, + system::Union{AbstractQuantumSystem, Nothing}=nothing, + rollout_integrator::Function=expv, +) + state_data = Matrix{Float64}[] + if isnothing(a_guess) + for (ψ̃_init, ψ̃_goal) ∈ zip(ψ̃_inits, ψ̃_goals) + ψ̃_traj = linear_interpolation(ψ̃_init, ψ̃_goal, T) + push!(state_data, ψ̃_traj) + end + if system isa AbstractVector + state_data = repeat(state_data, length(system)) + end + else + for sys ∈ system + for ψ̃_init ∈ ψ̃_inits + ψ̃_traj = rollout(ψ̃_init, a_guess, timesteps, sys; integrator=rollout_integrator) + push!(state_data, ψ̃_traj) + end + end + end + + return state_data +end + +""" + initialize_trajectory + +Trajectory initialization of quantum states can broadcast over multiple systems. +""" function initialize_trajectory( ψ_goals::AbstractVector{<:AbstractVector{ComplexF64}}, ψ_inits::AbstractVector{<:AbstractVector{ComplexF64}}, @@ -466,6 +542,7 @@ function initialize_trajectory( a_guess::Union{AbstractMatrix{<:Float64}, Nothing}=nothing, system::Union{AbstractQuantumSystem, AbstractVector{<:AbstractQuantumSystem}, Nothing}=nothing, rollout_integrator::Function=expv, + verbose=false, kwargs... ) @assert length(ψ_inits) == length(ψ_goals) "ψ_inits and ψ_goals must have the same length" @@ -474,66 +551,44 @@ function initialize_trajectory( ψ̃_goals = ket_to_iso.(ψ_goals) ψ̃_inits = ket_to_iso.(ψ_inits) - if isnothing(a_guess) - state_data = [] - for (ψ̃_init, ψ̃_goal) ∈ zip(ψ̃_inits, ψ̃_goals) - ψ̃_traj = linear_interpolation(ψ̃_init, ψ̃_goal, T) - push!(state_data, ψ̃_traj) - end - if system isa AbstractVector - state_data = repeat(state_data, length(system)) - end + if Δt isa AbstractMatrix + timesteps = vec(Δt) + elseif Δt isa Float64 + timesteps = fill(Δt, T) else - @assert size(a_guess, 1) == n_drives "a_guess must have n_drives = $(n_drives) drives" - @assert size(a_guess, 2) == T "a_guess must have T = $(T) timesteps" - @assert !isnothing(system) "system must be provided if a_guess is provided" - - if Δt isa AbstractMatrix - timesteps = vec(Δt) - elseif Δt isa Float64 - timesteps = fill(Δt, T) - else - timesteps = Δt - end - - if system isa AbstractVector - state_data = [] - for sys ∈ system - for ψ̃_init ∈ ψ̃_inits - ψ̃_traj = rollout(ψ̃_init, a_guess, timesteps, sys; - integrator=rollout_integrator - ) - push!(state_data, ψ̃_traj) - end - end - else - state_data = [] - for ψ̃_init ∈ ψ̃_inits - ψ̃_traj = rollout(ψ̃_init, a_guess, timesteps, system; - integrator=rollout_integrator - ) - push!(state_data, ψ̃_traj) - end - end + timesteps = Δt end - state_data = Matrix{Float64}.(state_data) - if system isa AbstractVector - if lenth(state_names) != length(system) * length(ψ_goals) - state_names = vcat([ - Symbol.(string.(state_names) .* "_system_$i") - for i = 1:length(system) - ]...) - @warn "length of state_names and number of systems ($(length(system))) * number of states ($(length(ψ_goals))) are not equal, created state names for each system: " state_names + state_data = map(system) do sys + initialize_quantum_states( + ψ̃_goals, ψ̃_inits, T, timesteps; + a_guess=a_guess, + system=sys, + rollout_integrator=rollout_integrator + ) + end + # Flatten state data along systems + state_data = vcat(state_data...) + # Update state names using systems + state_names = Symbol.([string(n) * "_system_$i" for i in eachindex(system) for n in state_names]) + if verbose + println("Created state names for ($(length(system))) systems: $state_names") end - ψ̃_inits = repeat(ψ̃_inits, length(system)) - ψ̃_goals = repeat(ψ̃_goals, length(system)) + state_inits = repeat(ψ̃_inits, length(system)) + state_goals = repeat(ψ̃_goals, length(system)) + else + state_data = initialize_quantum_states( + ψ̃_goals, ψ̃_inits, T, timesteps; + a_guess=a_guess, + system=system, + rollout_integrator=rollout_integrator + ) + # state_names are the same + state_inits = ψ̃_inits + state_goals = ψ̃_goals end - state_inits = ψ̃_inits - state_goals = ψ̃_goals - return initialize_trajectory( state_data, state_inits, @@ -543,6 +598,7 @@ function initialize_trajectory( Δt, args...; a_guess=a_guess, + verbose=verbose, kwargs... ) end