Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test and document variational rollouts
  • Loading branch information
andgoldschmidt committed Apr 3, 2025
commit c0a4fe3388901ea955c519938f43dd0cd3c9fab7
207 changes: 182 additions & 25 deletions src/rollouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,28 @@ rollout(ψ::Vector{<:Complex}, args...; kwargs...) =
function rollout(
inits::AbstractVector{<:AbstractVector}, args...; kwargs...
)
return vcat([rollout(state, args...; kwargs...) for state ∈ inits]...)
return [rollout(state, args...; kwargs...) for state ∈ inits]
end

function rollout(
traj::NamedTrajectory,
system::AbstractQuantumSystem;
state_name::Symbol=:ψ̃,
drive_name::Symbol=:a,
kwargs...
)
# Get the initial state names
state_names = [
name for name ∈ traj.names if startswith(string(name), string(state_name))
]

return rollout(
length(state_names) == 1 ? traj.initial[state_name] : [traj.initial[name] for name ∈ state_names],
traj[drive_name],
get_timesteps(traj),
system;
kwargs...
)
end

"""
Expand Down Expand Up @@ -643,7 +664,38 @@ end
# ----------------------------------------------------------------------------- #

"""
variational_rollout(
ψ̃_init::AbstractVector{<:Real},
controls::AbstractMatrix{<:Real},
Δt::AbstractVector{<:Real},
system::VariationalQuantumSystem;
show_progress::Bool=false,
integrator::Function=expv,
exp_vector_product::Bool=infer_is_evp(integrator)
)
variational_rollout(ψ::Vector{<:Complex}, args...; kwargs...)
variational_rollout(inits::AbstractVector{<:AbstractVector}, args...; kwargs...)
variational_rollout(
traj::NamedTrajectory,
system::AbstractQuantumSystem;
state_name::Symbol=:ψ̃,
drive_name::Symbol=:a,
kwargs...
)

Simulates the variational evolution of a quantum state under a given control trajectory.

# Returns
- `Ψ̃::Matrix{<:Real}`: The evolved quantum state at each timestep.
- `Ψ̃_vars::Vector{<:Matrix{<:Real}}`: The variational derivatives of the
quantum state with respect to the variational parameters.

# Notes
This function computes the variational evolution of a quantum state using the
variational generators of the system. It supports autodifferentiable controls and
timesteps, making it suitable for optimization tasks. The variational derivatives are
computed alongside the state evolution, enabling sensitivity analysis and gradient-based
optimization.
"""
function variational_rollout end

Expand Down Expand Up @@ -683,18 +735,96 @@ function variational_rollout(
next!(p)
end

# collect into list of matrices
Ψ̃_vars = collect.(eachslice(reshape(Ψ̃_vars, (:, size(Ψ̃)...)), dims=1))
# collect into vector of matrices
if length(system.G_vars) > 1
Ψ̃_vars = collect.(eachslice(reshape(Ψ̃_vars, (:, size(Ψ̃)...)), dims=1))
end

return Ψ̃, Ψ̃_vars
end

variational_rollout(ψ::Vector{<:Complex}, args...; kwargs...) =
variational_rollout(ket_to_iso(ψ), args...; kwargs...)

function variational_rollout(
inits::AbstractVector{<:AbstractVector}, args...; kwargs...
)
N = length(inits)

# First call
ψ̃1, ψ̃_vars1 = variational_rollout(inits[1], args...; kwargs...)

# Preallocate the rest
ψ̃s = Vector{typeof(ψ̃1)}(undef, N)
ψ̃_vars = Vector{typeof(ψ̃_vars1)}(undef, N)
ψ̃s[1] = ψ̃1
ψ̃_vars[1] = ψ̃_vars1
for i = 2:N
ψ̃s[i], ψ̃_vars[i] = variational_rollout(inits[i], args...; kwargs...)
end
return ψ̃s, ψ̃_vars
end


function variational_rollout(
traj::NamedTrajectory,
system::AbstractQuantumSystem;
state_name::Symbol=:ψ̃,
drive_name::Symbol=:a,
kwargs...
)
# Get the initial state names
state_names = [
name for name ∈ traj.names if startswith(string(name), string(state_name))
]

return variational_rollout(
length(state_names) == 1 ? traj.initial[state_name] : [traj.initial[name] for name ∈ state_names],
traj[drive_name],
get_timesteps(traj),
system;
kwargs...
)
end


"""
variational_unitary_rollout(
Ũ⃗_init::AbstractVector{<:Real},
controls::AbstractMatrix{<:Real},
Δt::AbstractVector{<:Real},
system::VariationalQuantumSystem;
show_progress::Bool=false,
integrator::Function=expv,
exp_vector_product::Bool=infer_is_evp(integrator)
)
variational_unitary_rollout(
controls::AbstractMatrix{<:Real},
Δt::AbstractVector,
system::VariationalQuantumSystem;
kwargs...
)
variational_unitary_rollout(
traj::NamedTrajectory,
system::VariationalQuantumSystem;
unitary_name::Symbol=:Ũ⃗,
drive_name::Symbol=:a,
kwargs...
)

Simulates the variational evolution of a quantum state under a given control trajectory.

# Returns
- `Ũ⃗::Matrix{<:Real}`: The evolved unitary at each timestep.
- `Ũ⃗_vars::Vector{<:Matrix{<:Real}}`: The variational derivatives of the unitary with
respect to the variational parameters.

# Notes
This function computes the variational evolution of a unitary using the
variational generators of the system. It supports autodifferentiable controls and
timesteps, making it suitable for optimization tasks. The variational derivatives are
computed alongside the state evolution, enabling sensitivity analysis and gradient-based
optimization.
"""
function variational_unitary_rollout end

Expand Down Expand Up @@ -737,8 +867,10 @@ function variational_unitary_rollout(
next!(p)
end

# collect into list of matrices
Ũ⃗_vars = collect.(eachslice(reshape(Ũ⃗_vars, (:, size(Ũ⃗)...)), dims=1))
# collect into vector of matrices
if length(system.G_vars) > 1
Ũ⃗_vars = collect.(eachslice(reshape(Ũ⃗_vars, (:, size(Ũ⃗)...)), dims=1))
end

return Ũ⃗, Ũ⃗_vars
end
Expand Down Expand Up @@ -783,24 +915,10 @@ end
include("../test/test_utils.jl")

traj = named_trajectory_type_1()

sys = QuantumSystem(0 * GATES[:Z], [GATES[:X], GATES[:Y]])

U_goal = GATES[:H]

sys = QuantumSystem([PAULIS.X, PAULIS.Y])
U_goal = GATES.H
embedded_U_goal = EmbeddedOperator(U_goal, sys)

# T = 51
# Δt = 0.2
# ts = fill(Δt, T)
# as = collect([π/(T-1)/Δt * sin.(π*(0:T-1)/(T-1)).^2 zeros(T)]')

# prob = UnitarySmoothPulseProblem(
# sys, U_goal, T, Δt, a_guess=as,
# ipopt_options=IpoptOptions(print_level=1),
# piccolo_options=PiccoloOptions(verbose=false, free_time=false)
# )

ψ = ComplexF64[1, 0]
ψ_goal = U_goal * ψ
ψ̃ = ket_to_iso(ψ)
Expand Down Expand Up @@ -855,7 +973,7 @@ end
using ForwardDiff
using ExponentialAction

sys = QuantumSystem(0 * GATES[:Z], [GATES[:X], GATES[:Y]])
sys = QuantumSystem([PAULIS.X, PAULIS.Y])
T = 51
Δt = 0.2
ts = fill(Δt, T)
Expand Down Expand Up @@ -890,9 +1008,48 @@ end
@test size(result2) == (iso_vec_dim, T)
end

@testitem "Variational rollouts" begin
# TODO: Add variational rollout tests
@test_broken false
@testitem "Test variational rollouts" begin
include("../test/test_utils.jl")
sys = QuantumSystem([PAULIS.X, PAULIS.Y])
varsys1 = VariationalQuantumSystem([PAULIS.X, PAULIS.Y], [PAULIS.X])
varsys2 = VariationalQuantumSystem([PAULIS.X, PAULIS.Y], [PAULIS.X, PAULIS.Y])
U_goal = GATES.H

# state rollouts
traj = named_trajectory_type_2()
ψ̃s_def = rollout(traj, sys)
dims = size(ψ̃s_def[1])
for vs in [varsys1, varsys2]
ψ̃s, ψ̃s_vars = variational_rollout(traj, vs)
@assert ψ̃s ≈ ψ̃s_def
if length(vs.G_vars) == 1
for ψ̃_var in ψ̃s_vars
@assert size(ψ̃_var) == dims
end
else
@assert length(ψ̃s_vars[1]) == length(vs.G_vars)
for ψ̃_var in ψ̃s_vars
for ψ̃_var_op in ψ̃_var
@assert size(ψ̃_var_op) == dims
end
end
end
end

# unitary rollouts
traj = named_trajectory_type_1()
Ũ⃗_def = unitary_rollout(traj, sys)

for vs in [varsys1, varsys2]
Ũ⃗, Ũ_vars = variational_unitary_rollout(traj, vs)
@assert Ũ⃗ ≈ Ũ⃗_def
if length(vs.G_vars) == 1
@assert size(Ũ_vars) == size(Ũ⃗_def)
else
@assert length(Ũ_vars) == length(vs.G_vars)
@assert size(Ũ_vars[1]) == size(Ũ⃗_def)
end
end
end


Expand Down
68 changes: 68 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,71 @@ function named_trajectory_type_1(; free_time=false)
goal=goal
)
end

function named_trajectory_type_2()
# X gate (initial states: |0⟩, |1⟩), two dda controls (solved), Δt = 0.2
data = [
0.0 0.0 -0.0 0.057 0.205 0.405 0.58 0.677 0.705 0.707
1.0 1.0 1.0 0.997 0.957 0.82 0.573 0.29 0.08 0.0
0.0 0.0 -0.0 -0.057 -0.205 -0.405 -0.58 -0.677 -0.705 -0.707
0.0 0.0 0.0 0.0 -0.0 -0.0 -0.0 -0.0 -0.0 -0.0
1.0 1.0 1.0 0.997 0.957 0.82 0.573 0.29 0.08 0.0
0.0 0.0 -0.0 -0.057 -0.205 -0.405 -0.58 -0.677 -0.705 -0.707
0.0 0.0 0.0 -0.0 0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 -0.0 -0.057 -0.205 -0.405 -0.58 -0.677 -0.705 -0.707
0.0 -0.0 0.145 0.387 0.57 0.634 0.57 0.387 0.145 0.0
0.0 0.0 -0.145 -0.387 -0.57 -0.634 -0.57 -0.387 -0.145 0.0
0.0 0.369 0.618 0.467 0.163 -0.163 -0.467 -0.618 -0.369 0.0
0.0 -0.369 -0.618 -0.467 -0.163 0.163 0.467 0.618 0.369 0.0
0.943 0.636 -0.386 -0.777 -0.833 -0.777 -0.386 0.636 0.943 0.0
-0.943 -0.636 0.386 0.777 0.833 0.777 0.386 -0.636 -0.943 0.0
0.392 0.392 0.392 0.392 0.392 0.392 0.392 0.392 0.392 0.392
]

components = (
ψ̃1 = data[1:4, :],
ψ̃2 = data[5:8, :],
a = data[9:10, :],
da = data[11:12, :],
dda = data[13:14, :],
Δt = data[15:15, :]
)

controls = (:dda, :Δt)

timestep = :Δt

bounds = (
a = ([-1.0, -1.0], [1.0, 1.0]),
da = ([-Inf, -Inf], [Inf, Inf]),
dda = ([-1.0, -1.0], [1.0, 1.0]),
Δt = ([0.0002], [0.4])
)

initial = (
ψ̃1 = [0.0, 1.0, 0.0, 0.0],
ψ̃2 = [1.0, 0.0, 0.0, 0.0],
a = [0.0, 0.0],
da = [0.0, 0.0]
)

final = (
a = [0.0, 0.0],
da = [0.0, 0.0]
)

goal = (
ψ̃1 = [1.0, 0.0, 0.0, 0.0],
ψ̃2 = [0.0, 1.0, 0.0, 0.0]
)

return NamedTrajectory(
components;
controls=controls,
timestep=timestep,
bounds=bounds,
initial=initial,
final=final,
goal=goal
)
end