Skip to content
Prev Previous commit
Next Next commit
Fixed Adjoint
  • Loading branch information
BBhattacharyya1729 committed Mar 26, 2025
commit a53ed73de07243c15f8d070737b47a863a5e85c9
50 changes: 37 additions & 13 deletions src/problem_templates/adjoint_smooth_pulse_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ function AdjointUnitarySmoothPulseProblem(
goal::AbstractPiccoloOperator,
T::Int,
Δt::Union{Float64, <:AbstractVector{Float64}};
times::Union{AbstractVector,Nothing}=nothing,
times::Vector{<:Union{AbstractVector,Nothing}}=[nothing for s∈system.Gₐ],
down_times::Vector{<:Union{AbstractVector,Nothing}}=[nothing for s∈system.Gₐ],
unitary_integrator=AdjointUnitaryIntegrator,
state_name::Symbol = :Ũ⃗,
state_adjoint_name::Symbol = :Ũ⃗ₐ,
Expand Down Expand Up @@ -36,7 +37,7 @@ function AdjointUnitarySmoothPulseProblem(

# Trajectory
if !isnothing(init_trajectory)
traj = init_trajectory
traj = copy(init_trajectory)
else
traj = initialize_trajectory(
goal,
Expand All @@ -52,21 +53,36 @@ function AdjointUnitarySmoothPulseProblem(
geodesic=piccolo_options.geodesic,
bound_state=piccolo_options.bound_state,
a_guess=a_guess,
system=system,
rollout_integrator=piccolo_options.rollout_integrator,
verbose=piccolo_options.verbose
)
end

adj = traj[state_name]
adj[:,1] *= 0

if !isnothing(init_trajectory)
adj = adjoint_unitary_rollout(traj,system;unitary_name=state_name,drive_name = control_name)[2]
end

add_component!(traj,state_adjoint_name,adj;type=:state)
traj.initial = merge(traj.initial, (state_adjoint_name => adj[:,1], ))
adj = [traj[state_name] for s ∈ system.Gₐ]


# for a∈adj
# a[:,1] *= 0
# end


full_rollout = adjoint_unitary_rollout(traj, system ;unitary_name=state_name,drive_name = control_name)[2]

for i in 1:length(system.Gₐ)
adj[i] = full_rollout[i]
end


state_adjoint_names = [add_suffix(state_adjoint_name,string(i)) for i in 1:length(system.Gₐ)]


for (name,a) in zip(state_adjoint_names,adj)
add_component!(traj,name,a;type=:state)
traj.initial = merge(traj.initial, (name=> a[:,1], ))

end

fidelity_constraint = FinalUnitaryFidelityConstraint(
goal, state_name, final_fidelity, traj
Expand All @@ -83,8 +99,16 @@ function AdjointUnitarySmoothPulseProblem(
J += QuadraticRegularizer(control_names[2], traj, R_da)
J += QuadraticRegularizer(control_names[3], traj, R_dda)

if(!isnothing(times))
J += UnitaryNormLoss(state_adjoint_name, traj, times; Q)
for (name,t) ∈ zip(state_adjoint_names,times)
if(!isnothing(t))
J += UnitaryNormLoss(name, traj, t; Q=Q)
end
end

for (name,down_t) ∈ zip(state_adjoint_names,down_times)
if(!isnothing(down_t))
J += UnitaryNormLoss(name, traj, down_t; Q=Q, rep = false)
end
end

# Optional Piccolo constraints and objectives
Expand All @@ -94,7 +118,7 @@ function AdjointUnitarySmoothPulseProblem(
)

integrators = [
unitary_integrator(system, traj, state_name,state_adjoint_name,control_name),
unitary_integrator(system, traj, state_name,state_adjoint_names,control_name),
DerivativeIntegrator(traj, control_name, control_names[2]),
DerivativeIntegrator(traj, control_names[2], control_names[3]),
]
Expand Down
11 changes: 9 additions & 2 deletions src/quantum_integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,17 @@ function AdjointUnitaryIntegrator(
sys::ParameterizedQuantumSystem,
traj::NamedTrajectory,
Ũ⃗::Symbol,
Ũ⃗ₐ::Symbol,
Ũ⃗ₐ::Vector{Symbol},
a::Symbol
)
Ĝ = a_ -> [I(sys.levels) ⊗ sys.G(a_) I(sys.levels) ⊗ sys.Gₐ(a_) ; I(sys.levels) ⊗ sys.G(a_)*0 I(sys.levels) ⊗ sys.G(a_) ]
n_sys = length(sys.Gₐ)

G = a_ -> I(sys.levels) ⊗ sys.G(a_)

Gai = (i,a_) -> I(sys.levels) ⊗ sys.Gₐ[i](a_)

Ĝ = a_ -> vcat(reduce(vcat,[[zeros(size(G(a_))[1],(i-1) * size(G(a_))[1]) G(a_) zeros(size(G(a_))[1],(n_sys-i) * size(G(a_))[1]) Gai(i,a_)] for i in 1:length(sys.Gₐ)]),[zeros(size(G(a_))[1],size(G(a_))[1]*n_sys) G(a_)])

return AdjointBilinearIntegrator(Ĝ, traj, Ũ⃗, Ũ⃗ₐ, a)
end

Expand Down
10 changes: 8 additions & 2 deletions src/quantum_objectives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,14 @@ function UnitaryNormLoss(
name::Symbol,
traj::NamedTrajectory,
times::AbstractVector;
Q::Float64=100.0
Q::Float64=100.0,
rep = true
)
ℓ = Ũ⃗-> 1/(Ũ⃗'Ũ⃗) * length(Ũ⃗)
if(rep)
ℓ = Ũ⃗-> 1/(Ũ⃗'Ũ⃗) * length(Ũ⃗)
else
ℓ = Ũ⃗-> (Ũ⃗'Ũ⃗) / length(Ũ⃗)
end
return KnotPointObjective(
ℓ,
name,
Expand All @@ -113,4 +118,5 @@ function UnitaryNormLoss(
)
end


end