Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add KnotPointObjective
  • Loading branch information
andgoldschmidt committed Mar 13, 2025
commit 31fa7d12e4805df7cddb986e0453a27f844d33f3
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DirectTrajOpt"
uuid = "c823fa1f-8872-4af5-b810-2b9b72bbbf56"
authors = ["Aaron Trowbridge <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
ExponentialAction = "e24c0720-ea99-47e8-929e-571b494574d3"
Expand Down
137 changes: 85 additions & 52 deletions src/objectives/_objectives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Objectives

export Objective
export NullObjective
export KnotPointObjective

using ..Constraints

Expand All @@ -13,23 +14,6 @@ using ForwardDiff
using TestItems


# include("quantum_objective.jl")
# include("unitary_infidelity_objective.jl")
include("regularizers.jl")
include("minimum_time_objective.jl")
include("terminal_loss.jl")
# include("unitary_robustness_objective.jl")
# include("density_operator_objectives.jl")

# TODO:
# - [ ] Do not reference the Z object in the objective (components only / remove "name")

"""
sparse_to_moi(A::SparseMatrixCSC)

Converts a sparse matrix to tuple of vector of nonzero indices and vector of nonzero values
"""

# ----------------------------------------------------------------------------- #
# Objective #
# ----------------------------------------------------------------------------- #
Expand All @@ -38,15 +22,12 @@ Converts a sparse matrix to tuple of vector of nonzero indices and vector of non
Objective

A structure for defining objective functions.

The `terms` field contains all the arguments needed to construct the objective function.


Fields:
`L`: the objective function
`∇L`: the gradient of the objective function
`∂²L`: the Hessian of the objective function
`∂²L_structure`: the structure of the Hessian of the objective function
`terms`: a vector of dictionaries containing the terms of the objective function
"""
struct Objective
L::Function
Expand All @@ -63,45 +44,97 @@ function Base.:+(obj1::Objective, obj2::Objective)
return Objective(L, ∇L, ∂²L, ∂²L_structure)
end

function Base.:*(num::Real, obj::Objective)
L = (Z⃗) -> num * obj.L(Z⃗)
∇L = (Z⃗) -> num * obj.∇L(Z⃗)
∂²L = (Z⃗) -> num * obj.∂²L(Z⃗)
return Objective(L, ∇L, ∂²L, obj.∂²L_structure)
end

Base.:*(obj::Objective, num::Real) = num * obj

# TODO: Unnecessary?
# Base.:+(obj::Objective, ::Nothing) = obj
# Base.:+(obj::Objective) = obj

# function Objective(terms::AbstractVector{<:Dict})
# return +(Objective.(terms)...)
# end

# function Base.:*(num::Real, obj::Objective)
# L = (Z⃗, Z) -> num * obj.L(Z⃗, Z)
# ∇L = (Z⃗, Z) -> num * obj.∇L(Z⃗, Z)
# if isnothing(obj.∂²L)
# ∂²L = nothing
# ∂²L_structure = nothing
# else
# ∂²L = (Z⃗, Z) -> num * obj.∂²L(Z⃗, Z)
# ∂²L_structure = obj.∂²L_structure
# end
# return Objective(L, ∇L, ∂²L, ∂²L_structure, obj.terms)
# end

# Base.:*(obj::Objective, num::Real) = num * obj

# function Objective(term::Dict)
# return eval(term[:type])(; delete!(term, :type)...)
# end
# ----------------------------------------------------------------------------- #
# Null objective
# ----------------------------------------------------------------------------- #

function NullObjective(Z::NamedTrajectory)
L(::AbstractVector{<:Real}) = 0.0
∇L(::AbstractVector{R}) where R<:Real = zeros(R, Z.dim * Z.T + Z.global_dim)
∂²L_structure() = []
∂²L(::AbstractVector{R}) where R<:Real = R[]
return Objective(L, ∇L, ∂²L, ∂²L_structure)
end

# ----------------------------------------------------------------------------- #
# Null objective #
# Knot Point objective
# ----------------------------------------------------------------------------- #

function NullObjective()
L(Z⃗::AbstractVector{R}) where R<:Real = 0.0
∇L(Z⃗::AbstractVector{R}) where R<:Real = zeros(R, Z.dim * Z.T + Z.global_dim)
∂²L_structure(Z::NamedTrajectory) = []
function ∂²L(Z⃗::AbstractVector{<:Real}; return_moi_vals=true)
n = Z.dim * Z.T + Z.global_dim
return return_moi_vals ? [] : spzeros(n, n)
function KnotPointObjective(
ℓ::Function,
name::Symbol,
traj::NamedTrajectory;
Qs::AbstractVector{Float64}=ones(traj.T),
times::AbstractVector{Int}=1:traj.T,
)
@assert length(Qs) == length(times) "Qs must have the same length as times"

Z_dim = traj.dim * traj.T + traj.global_dim
x_slices = [slice(t, traj.components[name], traj.dim) for t in times]

function L(Z⃗::AbstractVector{<:Real})
loss = 0.0
for (i, x_slice) in enumerate(x_slices)
x = Z⃗[x_slice]
loss += Qs[i] * ℓ(x)
end
return loss
end
return Objective(L, ∇L, ∂²L, ∂²L_structure)

@views function ∇L(Z⃗::AbstractVector{<:Real})
∇ = zeros(Z_dim)
for (i, x_slice) in enumerate(x_slices)
x = Z⃗[x_slice]
∇[x_slice] = ForwardDiff.gradient(x -> Qs[i] * ℓ(x), x)
end
return ∇
end

function ∂²L_structure()
structure = spzeros(Z_dim, Z_dim)
for x_slice in x_slices
structure[x_slice, x_slice] .= 1.0
end
structure_pairs = collect(zip(findnz(structure)[1:2]...))
return structure_pairs
end

@views function ∂²L(Z⃗::AbstractVector{<:Real})
∂²L_values = zeros(length(∂²L_structure()))
for (i, x_slice) in enumerate(x_slices)
∂²ℓ = ForwardDiff.hessian(x -> Qs[i] * ℓ(x), Z⃗[x_slice])
∂²ℓ_length = length(∂²ℓ[:])
∂²L_values[(i - 1) * ∂²ℓ_length + 1:i * ∂²ℓ_length] = ∂²ℓ[:]
end
return ∂²L_values
end

return Objective(L, ∇L, ∂²L, ∂²L_structure)
end

# ----------------------------------------------------------------------------- #
# Additional objectives
# ----------------------------------------------------------------------------- #

include("minimum_time_objective.jl")
include("regularizers.jl")
include("terminal_loss.jl")

# =========================================================================== #

# TODO: Testing (see constraints)

end
20 changes: 8 additions & 12 deletions src/objectives/minimum_time_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,38 @@ export MinimumTimeObjective
"""
MinimumTimeObjective

A type of objective that counts the time taken to complete a task.

Fields:
`D`: a scaling factor
`Δt_indices`: the indices of the time steps
`eval_hessian`: whether to evaluate the Hessian
A type of objective that counts the time taken to complete a task. `D` is a scaling factor.

"""
function MinimumTimeObjective(
traj::NamedTrajectory;
D::Float64=1.0,
timesteps_all_equal::Bool=false,
)
@assert traj.timestep isa Symbol "Trajectory timestep must be a symbol, i.e. free time"
@assert traj.timestep isa Symbol "Trajectory timestep must be a symbol (free time)"

if !timesteps_all_equal
Δt_indices = [index(t, traj.components[traj.timestep][1], traj.dim) for t = 1:traj.T-1]

L = Z⃗::AbstractVector -> D * sum(Z⃗[Δt_indices])
L = Z⃗::AbstractVector{<:Real} -> D * sum(Z⃗[Δt_indices])

∇L = (Z⃗::AbstractVector) -> begin
∇L = (Z⃗::AbstractVector{<:Real}) -> begin
∇ = zeros(eltype(Z⃗), length(Z⃗))
∇[Δt_indices] .= D
return ∇
end
else
Δt_T_index = index(traj.T, traj.components[traj.timestep][1], traj.dim)
L = Z⃗::AbstractVector -> D * Z⃗[Δt_T_index]
∇L = (Z⃗::AbstractVector) -> begin
L = Z⃗::AbstractVector{<:Real} -> D * Z⃗[Δt_T_index]
∇L = (Z⃗::AbstractVector{<:Real}) -> begin
∇ = zeros(eltype(Z⃗), length(Z⃗))
∇[Δt_T_index] = D
return ∇
end

end

∂²L = Z⃗ -> []
∂²L = Z⃗::AbstractVector{<:Real} -> []
∂²L_structure = () -> []
return Objective(L, ∇L, ∂²L, ∂²L_structure)
end
Loading
Loading