From d76fa7980f888875fd01d075a99913f734b84a4d Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Wed, 17 Jul 2024 07:22:52 -0400 Subject: [PATCH 1/5] Typos --- docs/src/benchmarks.md | 4 ++++ src/cheby.jl | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/src/benchmarks.md b/docs/src/benchmarks.md index 8ce55bf1..408f1156 100644 --- a/docs/src/benchmarks.md +++ b/docs/src/benchmarks.md @@ -6,3 +6,7 @@ Pages=[ ] Depth = 2 ``` + +```@raw html + +``` diff --git a/src/cheby.jl b/src/cheby.jl index d6f626a8..840cb64c 100644 --- a/src/cheby.jl +++ b/src/cheby.jl @@ -13,7 +13,7 @@ using TimerOutputs: @timeit_debug, TimerOutput a::Vector{Float64} = cheby_coeffs(Δ, dt; limit=1e-12) ``` -return an array of coefficiencts larger than `limit`. +return an array of coefficients larger than `limit`. # Arguments @@ -141,7 +141,7 @@ cheby!(Ψ, H, dt, wrk; E_min=nothing, check_normalization=false) `E_min`, as long as the spectra radius `Δ` and the time step `dt` are the same as those used for the initialization of `wrk`. * `check_normalizataion`: perform checks that the H does not exceed the - spectral radius for which the the workspace was initialized. + spectral radius for which the workspace was initialized. The routine will not allocate any internal storage. This implementation requires `copyto!` `lmul!`, and `axpy!` to be implemented for `Ψ`, and the From 5023b66caee640367260a9481fb821c9f2a6d6af Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Fri, 19 Jul 2024 09:07:26 -0400 Subject: [PATCH 2/5] Properly define the in-place vs not-in-place interface We support both mutating and non-mutating propagators. Mutation is better for large Hilbert spaces. Non-mutation is better for small Hilbert spaces (`StaticArrays`!) or when trying to use automatic differentiation. There are some subtleties in finding the correct abstraction. It is not as simple as using the built-in `ismutable` for states or operators and making decisions based on that: Anytime we use custom structs, unless that struct is explicitly defined as `mutable`, it is considered immutable. However, we can still use in-place propagation, mutating the mutable *components* of that struct. Instead of overloading `ismutable`, we define the in-place or not-in-place interface explicitly via the required behavior guaranteed by the `check_state`, `check_generator`, and `check_operator` functions. A new `QuantumPropagators.Interfaces.supports_inplace` function is available to check whether a given `state` or `operator` type is suitable for in-place operations. --- Project.toml | 3 +- ext/QuantumPropagatorsODEExt.jl | 5 +- ext/QuantumPropagatorsStaticArraysExt.jl | 9 + src/QuantumPropagators.jl | 35 +-- src/cheby_propagator.jl | 11 +- src/exp_propagator.jl | 14 +- src/expprop.jl | 1 - src/generators.jl | 6 +- src/interfaces/generator.jl | 35 +-- src/interfaces/operator.jl | 81 ++++--- src/interfaces/propagator.jl | 30 +-- src/interfaces/state.jl | 225 +++++++++++------- src/interfaces/supports_inplace.jl | 41 ++++ src/newton_propagator.jl | 12 +- src/ode_function.jl | 9 +- src/propagate.jl | 25 +- src/propagator.jl | 8 +- test/test_invalid_interfaces.jl | 289 ++++++++++++++--------- 18 files changed, 504 insertions(+), 335 deletions(-) create mode 100644 ext/QuantumPropagatorsStaticArraysExt.jl create mode 100644 src/interfaces/supports_inplace.jl diff --git a/Project.toml b/Project.toml index 54a61de0..dab345e8 100644 --- a/Project.toml +++ b/Project.toml @@ -11,17 +11,18 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" [weakdeps] OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] QuantumPropagatorsODEExt = "OrdinaryDiffEq" QuantumPropagatorsRecursiveArrayToolsExt = "RecursiveArrayTools" +QuantumPropagatorsStaticArraysExt = "StaticArrays" [compat] OffsetArrays = "1" diff --git a/ext/QuantumPropagatorsODEExt.jl b/ext/QuantumPropagatorsODEExt.jl index 0d1328d7..0567661c 100644 --- a/ext/QuantumPropagatorsODEExt.jl +++ b/ext/QuantumPropagatorsODEExt.jl @@ -12,6 +12,7 @@ using QuantumPropagators: ode_function using QuantumPropagators.Controls: get_controls, get_parameters, evaluate, evaluate!, discretize_on_midpoints +using QuantumPropagators.Interfaces: supports_inplace import QuantumPropagators: init_prop, reinit_prop!, prop_step!, set_state!, set_t! @@ -24,7 +25,7 @@ ode_propagator = init_prop( generator, tlist; method=OrdinaryDiffEq, # or: `method=DifferentialEquations` - inplace=true, + inplace=QuantumPropagators.Interfaces.supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -92,7 +93,7 @@ function init_prop( generator, tlist, method::Val{:OrdinaryDiffEq}; - inplace=true, + inplace=supports_inplace(state), backward=false, verbose=false, parameters=nothing, diff --git a/ext/QuantumPropagatorsStaticArraysExt.jl b/ext/QuantumPropagatorsStaticArraysExt.jl new file mode 100644 index 00000000..a51da4cb --- /dev/null +++ b/ext/QuantumPropagatorsStaticArraysExt.jl @@ -0,0 +1,9 @@ +module QuantumPropagatorsStaticArraysExt + +import QuantumPropagators.Interfaces: supports_inplace +using StaticArrays: SArray, MArray + +supports_inplace(::SArray) = false +supports_inplace(::MArray) = true + +end diff --git a/src/QuantumPropagators.jl b/src/QuantumPropagators.jl index 08d94dcb..caca83fe 100644 --- a/src/QuantumPropagators.jl +++ b/src/QuantumPropagators.jl @@ -34,30 +34,33 @@ include("propagator.jl") export init_prop, reinit_prop!, prop_step! # not exported: set_t!, set_state! -include("pwc_utils.jl") -include("cheby_propagator.jl") -include("newton_propagator.jl") -include("exp_propagator.jl") - -include("ode_function.jl") - #! format: off module Interfaces + export supports_inplace export check_operator, check_state, check_tlist, check_amplitude export check_control, check_generator, check_propagator export check_parameterized_function, check_parameterized - include(joinpath("interfaces", "utils.jl")) - include(joinpath("interfaces", "state.jl")) - include(joinpath("interfaces", "tlist.jl")) - include(joinpath("interfaces", "operator.jl")) - include(joinpath("interfaces", "amplitude.jl")) - include(joinpath("interfaces", "control.jl")) - include(joinpath("interfaces", "generator.jl")) - include(joinpath("interfaces", "propagator.jl")) - include(joinpath("interfaces", "parameterization.jl")) + include("interfaces/supports_inplace.jl") + include("interfaces/utils.jl") + include("interfaces/state.jl") + include("interfaces/tlist.jl") + include("interfaces/operator.jl") + include("interfaces/amplitude.jl") + include("interfaces/control.jl") + include("interfaces/generator.jl") + include("interfaces/propagator.jl") + include("interfaces/parameterization.jl") end #! format: on +include("pwc_utils.jl") +include("cheby_propagator.jl") +include("newton_propagator.jl") +include("exp_propagator.jl") + +include("ode_function.jl") + + include("timings.jl") # high-level interface diff --git a/src/cheby_propagator.jl b/src/cheby_propagator.jl index 159a7077..e912e1b4 100644 --- a/src/cheby_propagator.jl +++ b/src/cheby_propagator.jl @@ -1,4 +1,5 @@ using .Controls: get_controls, evaluate, discretize +using .Interfaces: supports_inplace using TimerOutputs: reset_timer!, @timeit_debug, TimerOutput """Propagator for Chebychev propagation (`method=QuantumPropagators.Cheby`). @@ -38,7 +39,7 @@ cheby_propagator = init_prop( generator, tlist; method=Cheby, - inplace=true, + inplace=QuantumPropagators.Interfaces.supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -88,7 +89,7 @@ function init_prop( generator, tlist, method::Val{:Cheby}; - inplace=true, + inplace=supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -356,7 +357,11 @@ function prop_step!(propagator::ChebyPropagator) tlist = getfield(propagator, :tlist) (0 < n < length(tlist)) || return nothing if propagator.inplace - H = _pwc_set_genop!(propagator, n) + if supports_inplace(H) + H = _pwc_set_genop!(propagator, n) + else + H = _pwc_get_genop(propagator, n) + end Cheby.cheby!( Ψ, H, diff --git a/src/exp_propagator.jl b/src/exp_propagator.jl index 83863c5f..6737a6b3 100644 --- a/src/exp_propagator.jl +++ b/src/exp_propagator.jl @@ -1,4 +1,5 @@ using .Controls: get_controls +using .Interfaces: supports_inplace using TimerOutputs: reset_timer!, @timeit_debug, TimerOutput """Propagator for propagation via direct exponentiation @@ -46,7 +47,7 @@ exp_propagator = init_prop( generator, tlist; method=ExpProp, - inplace=true, + inplace=QuantumPropagators.Interfaces.supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -82,7 +83,7 @@ function init_prop( generator, tlist, method::Val{:ExpProp}; - inplace=true, + inplace=supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -142,6 +143,7 @@ init_prop(state, generator, tlist, method::Val{:expprop}; kwargs...) = function prop_step!(propagator::ExpPropagator) @timeit_debug propagator.timing_data "prop_step!" begin + H = propagator.genop n = propagator.n tlist = getfield(propagator, :tlist) (0 < n < length(tlist)) || return nothing @@ -151,8 +153,12 @@ function prop_step!(propagator::ExpPropagator) end Ψ = convert(propagator.convert_state, propagator.state) if propagator.inplace - _pwc_set_genop!(propagator, n) - H = convert(propagator.convert_operator, propagator.genop) + if supports_inplace(propagator.genop) + _pwc_set_genop!(propagator, n) + H = convert(propagator.convert_operator, propagator.genop) + else + H = convert(propagator.convert_operator, _pwc_get_genop(propagator, n)) + end ExpProp.expprop!(Ψ, H, dt, propagator.wrk; func=propagator.func) if Ψ ≢ propagator.state # `convert` of Ψ may have been a no-op copyto!(propagator.state, convert(typeof(propagator.state), Ψ)) diff --git a/src/expprop.jl b/src/expprop.jl index b6688e4e..a5955145 100644 --- a/src/expprop.jl +++ b/src/expprop.jl @@ -4,7 +4,6 @@ module ExpProp export ExpPropWrk, expprop! using LinearAlgebra -import StaticArrays using TimerOutputs: @timeit_debug, TimerOutput diff --git a/src/generators.jl b/src/generators.jl index 260e789b..b76fa716 100644 --- a/src/generators.jl +++ b/src/generators.jl @@ -102,7 +102,11 @@ constant `Number`. If the number of coefficients is less than the number of operators, the first `ops` are considered to have ``c_l = 1``. An `Operator` object would generally not be instantiated directly, but be -obtained from a (@ref) via [`evaluate`](@ref). +obtained from a [`Generator`](@ref) via [`evaluate`](@ref). + +The ``Ĥ_l`` in the sum are considered immutable. This implies that an +`Operator` can be updated in-place with [`evaluate!`](@ref) by only changing +the `coeffs`. """ struct Operator{OT,CT<:Number} diff --git a/src/interfaces/generator.jl b/src/interfaces/generator.jl index 0c7cfad0..845fe375 100644 --- a/src/interfaces/generator.jl +++ b/src/interfaces/generator.jl @@ -6,8 +6,6 @@ using ..Generators: Generator ```julia @test check_generator( generator; state, tlist, - for_mutable_operator=true, for_immutable_operator=true, - for_mutable_state=true, for_immutable_state=true, for_pwc=true, for_time_continuous=false, for_expval=true, for_parameterization=false, atol=1e-14, quiet=false) @@ -26,10 +24,10 @@ verifies the given `generator`: If `for_pwc` (default): -* [`evaluate(generator, tlist, n)`](@ref evaluate) must return a valid +* [`op = evaluate(generator, tlist, n)`](@ref evaluate) must return a valid operator ([`check_operator`](@ref)), with forwarded keyword arguments (including `for_expval`) -* If `for_mutable_operator`, +* If `QuantumPropagators.Interfaces.supports_inplace(op)` is `true`, [`evaluate!(op, generator, tlist, n)`](@ref evaluate!) must be defined If `for_time_continuous`: @@ -37,8 +35,9 @@ If `for_time_continuous`: * [`evaluate(generator, t)`](@ref evaluate) must return a valid operator ([`check_operator`](@ref)), with forwarded keyword arguments (including `for_expval`) -* If `for_mutable_operator`, [`evaluate!(op, generator, t)`](@ref evaluate!) - must be defined + +* If `QuantumPropagators.Interfaces.supports_inplace(op)` is `true`, + [`evaluate!(op, generator, t)`](@ref evaluate!) must be defined If `for_parameterization` (may require the `RecursiveArrayTools` package to be loaded): @@ -55,10 +54,6 @@ function check_generator( generator; state, tlist, - for_mutable_operator=true, - for_immutable_operator=true, - for_mutable_state=true, - for_immutable_state=true, for_expval=true, for_pwc=true, for_time_continuous=false, @@ -69,7 +64,7 @@ function check_generator( _check_amplitudes=true # undocumented (internal use) ) - @assert check_state(state; for_mutable_state, for_immutable_state, atol, quiet=true) + @assert check_state(state; atol, quiet=true) @assert tlist isa Vector{Float64} @assert length(tlist) >= 2 @@ -170,8 +165,6 @@ function check_generator( op; state, tlist, - for_mutable_state, - for_immutable_state, for_expval, atol, quiet, @@ -189,14 +182,13 @@ function check_generator( success = false end + op = evaluate(generator, tlist, 1) + try - op = evaluate(generator, tlist, 1; vals_dict) if !check_operator( op; state, tlist, - for_mutable_state, - for_immutable_state, for_expval, atol, quiet, @@ -214,9 +206,8 @@ function check_generator( success = false end - if for_mutable_operator + if success && supports_inplace(op) try - op = evaluate(generator, tlist, 1) evaluate!(op, generator, tlist, length(tlist) - 1) catch exc quiet || @error( @@ -247,8 +238,6 @@ function check_generator( op; state, tlist, - for_mutable_state, - for_immutable_state, for_expval, atol, quiet, @@ -272,8 +261,6 @@ function check_generator( op; state, tlist, - for_mutable_state, - for_immutable_state, for_expval, atol, quiet, @@ -291,9 +278,9 @@ function check_generator( success = false end - if for_mutable_operator + op = evaluate(generator, tlist[begin]) + if success && supports_inplace(op) try - op = evaluate(generator, tlist[begin]) evaluate!(op, generator, tlist[end]) catch exc quiet || @error( diff --git a/src/interfaces/operator.jl b/src/interfaces/operator.jl index 52d0c1fb..b9637d5a 100644 --- a/src/interfaces/operator.jl +++ b/src/interfaces/operator.jl @@ -8,7 +8,6 @@ using ..Controls: get_controls, evaluate ```julia @test check_operator(op; state, tlist=[0.0, 1.0], - for_mutable_state=true, for_immutable_state=true, for_expval=true, atol=1e-14, quiet=false) ``` @@ -20,12 +19,12 @@ time-dependent dynamic generator. The specific requirements for `op` are: * `op` must not be time-dependent: [`evaluate(op, tlist, 1) ≡ op`](@ref evaluate) * `op` must not contain any controls: [`length(get_controls(op)) == 0`](@ref get_controls) - -If `for_immutable_state` (e.g., for use in propagators with `inplace=false`): - * `op * state` must be defined +* The [`QuantumPropagators.Interfaces.supports_inplace`](@ref) method must be + defined for `op`. If it returns `true`, it must be possible to evaluate a + generator in-place into the existing `op`. See [`check_generator`](@ref). -If `for_mutable_state` (e.g., for use in propagators with `inplace=true`): +If [`QuantumPropagators.Interfaces.supports_inplace(state)`](@ref): * The 3-argument `LinearAlgebra.mul!` must apply `op` to the given `state` * The 5-argument `LinearAlgebra.mul!` must apply `op` to the given `state` @@ -45,8 +44,6 @@ function check_operator( op; state, tlist=[0.0, 1.0], - for_mutable_state=true, - for_immutable_state=true, for_expval=true, atol=1e-14, quiet=false, @@ -60,9 +57,19 @@ function check_operator( Ψ = state - @assert check_state(state; for_mutable_state, for_immutable_state, atol, quiet=true) + @assert check_state(state; atol, quiet=true) @assert tlist isa Vector{Float64} + try + supports_inplace(op) + catch exc + quiet || @error( + "$(px)The `QuantumPropagators.Interfaces.supports_inplace` method must be defined for `op`.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end + try H = evaluate(op, tlist, 1) if H ≢ op @@ -92,26 +99,24 @@ function check_operator( success = false end - if for_immutable_state - try - ϕ = op * Ψ - if !(ϕ isa typeof(Ψ)) - quiet || - @error "$(px)`op * state` must return an object of the same type as `state`, not $typeof(ϕ)" - success = false - end - catch exc - quiet || @error( - "$(px)`op * state` must be defined.", - exception = (exc, catch_abbreviated_backtrace()) - ) + try + ϕ = op * Ψ + if !(ϕ isa typeof(Ψ)) + quiet || + @error "$(px)`op * state` must return an object of the same type as `state`, not $typeof(ϕ)" success = false end - + catch exc + quiet || @error( + "$(px)`op * state` must be defined.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false end - if for_mutable_state + + if supports_inplace(state) try ϕ = similar(Ψ) @@ -119,11 +124,9 @@ function check_operator( quiet || @error "$(px)`mul!(ϕ, op, state)` must return the resulting ϕ" success = false end - if for_immutable_state - if norm(ϕ - op * Ψ) > atol - quiet || @error "$(px)`mul!(ϕ, op, state)` must match `op * state`" - success = false - end + if norm(ϕ - op * Ψ) > atol + quiet || @error "$(px)`mul!(ϕ, op, state)` must match `op * state`" + success = false end catch exc quiet || @error( @@ -143,12 +146,10 @@ function check_operator( @error "$(px)`mul!(ϕ, op, state, α, β)` must return the resulting ϕ" success = false end - if for_immutable_state - if norm(ϕ - (0.5 * ϕ0 + 0.5 * (op * Ψ))) > atol - quiet || - @error "$(px)`mul!(ϕ, op, state, α, β)` must match β*ϕ + α*op*state" - success = false - end + if norm(ϕ - (0.5 * ϕ0 + 0.5 * (op * Ψ))) > atol + quiet || + @error "$(px)`mul!(ϕ, op, state, α, β)` must match β*ϕ + α*op*state" + success = false end catch exc quiet || @error( @@ -169,14 +170,12 @@ function check_operator( @error "$(px)`dot(state, op, state)` must return a number, not $typeof(val)" success = false end - if for_immutable_state - if abs(dot(Ψ, op, Ψ) - dot(Ψ, op * Ψ)) > atol - quiet || - @error "$(px)`dot(state, op, state)` must match `dot(state, op * state)`" - success = false - end + if abs(dot(Ψ, op, Ψ) - dot(Ψ, op * Ψ)) > atol + quiet || + @error "$(px)`dot(state, op, state)` must match `dot(state, op * state)`" + success = false end - if for_mutable_state + if supports_inplace(state) ϕ = similar(Ψ) mul!(ϕ, op, Ψ) if abs(dot(Ψ, op, Ψ) - dot(Ψ, ϕ)) > atol diff --git a/src/interfaces/propagator.jl b/src/interfaces/propagator.jl index a6432426..3f6a420a 100644 --- a/src/interfaces/propagator.jl +++ b/src/interfaces/propagator.jl @@ -19,9 +19,9 @@ initialized with [`init_prop`](@ref). * `propagator` must have the properties `state`, `tlist`, `t`, `parameters`, `backward`, and `inplace` -* `propagator.state` must be a valid state (see [`check_state`](@ref)), with - support for in-place operations (`for_mutable_state=true`) if - `propagator.inplace` is true. +* `propagator.state` must be a valid state (see [`check_state`](@ref)) +* If `propagator.inplace` is true, [`supports_inplace`](@ref) for + `propagator.state` must also be true * `propagator.tlist` must be monotonically increasing. * `propagator.t` must be the first or last element of `propagator.tlist`, depending on `propagator.backward` @@ -89,17 +89,18 @@ function check_propagator( end try - valid_state = check_state( - state; - for_mutable_state=inplace, - for_immutable_state=true, - atol, - _message_prefix="On `propagator.state`: " - ) + valid_state = check_state(state; atol, _message_prefix="On `propagator.state`: ") if !valid_state quiet || @error "$(px)`propagator.state` is not a valid state" success = false end + if inplace + if !supports_inplace(state) + quiet || + @error "$(px)`supports_inplace(propagator.state)` is false even for an in-place `propagator`" + success = false + end + end catch exc quiet || @error( "$(px)Cannot verify `propagator.state`.", @@ -151,13 +152,8 @@ function check_propagator( try Ψ₁ = prop_step!(propagator) - valid_state = check_state( - Ψ₁; - for_mutable_state=false, - for_immutable_state=false, - atol, - _message_prefix="On `Ψ₁=prop_step!(propagator)`: " - ) + valid_state = + check_state(Ψ₁; atol, _message_prefix="On `Ψ₁=prop_step!(propagator)`: ") if !valid_state quiet || @error "$(px)prop_step! must return a valid state until time grid is exhausted" diff --git a/src/interfaces/state.jl b/src/interfaces/state.jl index 11372300..87a22db0 100644 --- a/src/interfaces/state.jl +++ b/src/interfaces/state.jl @@ -6,39 +6,44 @@ using LinearAlgebra """Check that `state` is a valid element of a Hilbert space. ```julia -@test check_state(state; - for_immutable_state=true, for_mutable_state=true, - normalized=false, atol=1e-15, quiet=false) +@test check_state(state; normalized=false, atol=1e-15, quiet=false) ``` verifies the following requirements: * The inner product (`LinearAlgebra.dot`) of two states must return a Complex - number type. + number. * The `LinearAlgebra.norm` of `state` must be defined via the inner product. - This is the *definition* of a Hilbert space, a.k.a a complete inner product - space or more precisely a Banach space (normed vector space) where the - norm is induced by an inner product. + This is the *definition* of a Hilbert space, a.k.a a "complete inner product + space" or more precisely a "Banach space (normed vector space) where the + norm is induced by an inner product". +* The [`QuantumPropagators.Interfaces.supports_inplace](@ref) method must be + defined for `state` -If `for_immutable_state`: +Any `state` must support the following not-in-place operations: * `state + state` and `state - state` must be defined -* `copy(state)` must be defined +* `copy(state)` must be defined and return an object of the same type as + `state` * `c * state` for a scalar `c` must be defined * `norm(state + state)` must fulfill the triangle inequality +* `zero(state)` must be defined and produce a state with norm 0 * `0.0 * state` must produce a state with norm 0 * `copy(state) - state` must have norm 0 * `norm(state)` must have absolute homogeneity: `norm(s * state) = s * norm(state)` -If `for_mutable_state`: +If `supports_inplace(state)` is `true`, the `state` must also support the +following: -* `similar(state)` must be defined and return a valid state +* `similar(state)` must be defined and return a valid state of the same type a + `state` * `copyto!(other, state)` must be defined +* `fill!(state, c)` must be defined * `LinearAlgebra.lmul!(c, state)` for a scalar `c` must be defined * `LinearAlgebra.axpy!(c, state, other)` must be defined * `norm(state)` must fulfill the same general mathematical norm properties as - with `for_immutable_state`. + for the non-in-place norm. If `normalized` (not required by default): @@ -53,8 +58,6 @@ failed. """ function check_state( state; - for_immutable_state=true, - for_mutable_state=true, normalized=false, atol=1e-14, quiet=false, @@ -66,6 +69,17 @@ function check_state( px = _message_prefix success = true + inplace = false + + try + inplace = supports_inplace(state) + catch exc + quiet || @error( + "$(px)The `QuantumPropagators.Interfaces.supports_inplace` method must be defined for `state`.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end try c = state ⋅ state @@ -98,78 +112,100 @@ function check_state( success = false end - if for_immutable_state + # not-in-place interface: - try - Δ = norm(state - state) - if Δ > atol - quiet || @error "`$(px)state - state` must have norm 0" - success = false - end - η = norm(state + state) - if η > (2 * norm(state) + atol) - quiet || - @error "`$(px)norm(state + state)` must fulfill the triangle inequality" - success = false - end - catch exc - quiet || @error( - "$(px)`state + state` and `state - state` must be defined.", - exception = (exc, catch_abbreviated_backtrace()) - ) + try + Δ = norm(state - state) + if Δ > atol + quiet || @error "`$(px)state - state` must have norm 0" success = false end - - try - ϕ = copy(state) - if norm(ϕ - state) > atol - quiet || @error "$(px)`copy(state) - state` must have norm 0" - success = false - end - catch exc - quiet || @error( - "$(px)copy(state) must be defined.", - exception = (exc, catch_abbreviated_backtrace()) - ) + η = norm(state + state) + if η > (2 * norm(state) + atol) + quiet || + @error "`$(px)norm(state + state)` must fulfill the triangle inequality" success = false end + catch exc + quiet || @error( + "$(px)`state + state` and `state - state` must be defined.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end - try - ϕ = 0.5 * state - if abs(norm(ϕ) - 0.5 * norm(state)) > atol - quiet || - @error "$(px)`norm(state)` must have absolute homogeneity: `norm(s * state) = s * norm(state)`" - success = false - end - catch exc - quiet || @error( - "$(px)`c * state` for a scalar `c` must be defined.", - exception = (exc, catch_abbreviated_backtrace()) - ) + try + ϕ = copy(state) + if typeof(ϕ) != typeof(state) + quiet || + @error "$(px)`copy(state)::$(typeof(ϕ))` must have the same type as `state::$(typeof(state))`" success = false end + if norm(ϕ - state) > atol + quiet || @error "$(px)`copy(state) - state` must have norm 0" + success = false + end + catch exc + quiet || @error( + "$(px)copy(state) must be defined.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end - try - if norm(0.0 * state) > atol - quiet || @error "$(px)`0.0 * state` must produce a state with norm 0" - success = false - end - catch exc - quiet || @error( - "$(px)`0.0 * state` must produce a state with norm 0.", - exception = (exc, catch_abbreviated_backtrace()) - ) + try + ϕ = 0.5 * state + if abs(norm(ϕ) - 0.5 * norm(state)) > atol + quiet || + @error "$(px)`norm(state)` must have absolute homogeneity: `norm(s * state) = s * norm(state)`" success = false end + catch exc + quiet || @error( + "$(px)`c * state` for a scalar `c` must be defined.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end + try + state_zero = zero(state) + if norm(state_zero) > atol + quiet || @error("$(px)`zero(state)` must produce a state with norm 0.") + end + catch exc + quiet || @error( + "$(px)`zero(state)` must be defined and produce a state with norm 0.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end + + try + if norm(0.0 * state) > atol + quiet || @error "$(px)`0.0 * state` must produce a state with norm 0" + success = false + end + catch exc + quiet || @error( + "$(px)`0.0 * state` must produce a state with norm 0.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false end - if for_mutable_state + + if inplace has_similar = true similar_is_valid = true try ϕ = similar(state) + if typeof(ϕ) != typeof(state) + quiet || + @error "$(px)`similar(state)::$(typeof(ϕ))` must have the same type as `state::$(typeof(state))`" + success = false + end catch exc quiet || @error( "$(px)`similar(state)` must be defined.", @@ -189,8 +225,6 @@ function check_state( # the checks if !check_state( ϕ; - for_immutable_state, - for_mutable_state, normalized=false, atol, _check_similar=false, @@ -202,12 +236,10 @@ function check_state( success = false end end - if for_immutable_state - if norm(ϕ - state) > atol - quiet || - @error "$(px)`ϕ - state` must have norm 0, where `ϕ = similar(state); copyto!(ϕ, state)`" - success = false - end + if norm(ϕ - state) > atol + quiet || + @error "$(px)`ϕ - state` must have norm 0, where `ϕ = similar(state); copyto!(ϕ, state)`" + success = false end end catch exc @@ -218,17 +250,37 @@ function check_state( success = false end + try + if has_similar && similar_is_valid + ϕ = similar(state) + copyto!(ϕ, state) + state_zero = fill!(ϕ, 0.0) + if !isa(state_zero, typeof(ϕ)) + quiet || @error "$(px)`fill!(state, 0.0)` must return the filled state" + success = false + end + if norm(state_zero) > atol + quiet || @error "$(px)`fill!(state, 0.0)` must have norm 0" + success = false + end + end + catch exc + quiet || @error( + "$(px)`fill!(state, c)` must be defined.", + exception = (exc, catch_abbreviated_backtrace()) + ) + success = false + end + try if has_similar && similar_is_valid ϕ = similar(state) copyto!(ϕ, state) ϕ = lmul!(1im, ϕ) - if for_immutable_state - if norm(ϕ - 1im * state) > atol - quiet || - @error "$(px)`norm(state)` must have absolute homogeneity: `norm(s * state) = s * norm(state)`" - success = false - end + if norm(ϕ - 1im * state) > atol + quiet || + @error "$(px)`norm(state)` must have absolute homogeneity: `norm(s * state) = s * norm(state)`" + success = false end end catch exc @@ -263,12 +315,9 @@ function check_state( ϕ = similar(state) copyto!(ϕ, state) ϕ = axpy!(1im, state, ϕ) - if for_immutable_state - if norm(ϕ - (state + 1im * state)) > atol - quiet || - @error "$(px)`axpy!(a, state, ϕ)` must match `ϕ += a * state`" - success = false - end + if norm(ϕ - (state + 1im * state)) > atol + quiet || @error "$(px)`axpy!(a, state, ϕ)` must match `ϕ += a * state`" + success = false end end catch exc @@ -279,7 +328,7 @@ function check_state( success = false end - end + end # in-place interface if normalized try diff --git a/src/interfaces/supports_inplace.jl b/src/interfaces/supports_inplace.jl new file mode 100644 index 00000000..04c35e35 --- /dev/null +++ b/src/interfaces/supports_inplace.jl @@ -0,0 +1,41 @@ +import ..Operator +import ..ScaledOperator +import LinearAlgebra + +"""Indicate whether a given state or operator supports in-place operations + +```julia +supports_inplace(state) +``` + +Indicates that propagators can assume that the in-place requirements defined +in [`QuantumPropagators.Interfaces.check_state`](@ref) hold. States with +in-place support must also fulfill specific properties when interacting with +operators, see [`QuantumPropagators.Interfaces.check_operator`](@ref). + +```julia +supports_inplace(op) +``` + +Indicates that the operator can be evaluated in-place with [`evaluate!`](@ref), +see [`QuantumPropagators.Interfaces.check_generator`](@ref) + +Note that `supports_inplace` is not quite the same as +[`Base.ismutable`](@extref): When using [custom structs](@extref Julia +:label:`Mutable-Composite-Types`) for states or operators, even if those +structs are not defined as `mutable`, they may still define the in-place +interface (typically because their *components* are mutable). +""" +supports_inplace(state::Vector{ComplexF64}) = true +supports_inplace(state::AbstractVector) = ismutable(state) # fallback +# The fallback doesn't actually guarantee that the required interface implied +# by `supports_inplace` is fulfilled, but it's a reasonable expectation to +# have, and the `check_state` function will test it. + +supports_inplace(op::Matrix) = true +supports_inplace(op::Operator) = true +supports_inplace(op::LinearAlgebra.Diagonal) = true +supports_inplace(op::ScaledOperator) = supports_inplace(op.operator) +supports_inplace(op::AbstractMatrix) = ismutable(op) # fallback + +# Note: methods for StaticArrays are defined in package extension diff --git a/src/newton_propagator.jl b/src/newton_propagator.jl index 43cf6234..fadaf2f5 100644 --- a/src/newton_propagator.jl +++ b/src/newton_propagator.jl @@ -1,4 +1,5 @@ using .Controls: get_controls +using .Interfaces: supports_inplace using TimerOutputs: reset_timer!, @timeit_debug """Propagator for Newton propagation (`method=QuantumPropagators.Newton`). @@ -36,7 +37,7 @@ newton_propagator = init_prop( generator, tlist; method=Newton, - inplace=true, + inplace=QuantumPropagators.Interfaces.supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -63,7 +64,7 @@ function init_prop( generator, tlist, method::Val{:Newton}; - inplace=true, + inplace=supports_inplace(state), backward=false, verbose=false, parameters=nothing, @@ -119,6 +120,7 @@ init_prop(state, generator, tlist, method::Val{:newton}; kwargs...) = function prop_step!(propagator::NewtonPropagator) @timeit_debug propagator.timing_data "prop_step!" begin Ψ = propagator.state + H = propagator.genop n = propagator.n # index of interval we're going to propagate tlist = getfield(propagator, :tlist) (0 < n < length(tlist)) || return nothing @@ -127,7 +129,11 @@ function prop_step!(propagator::NewtonPropagator) dt = -dt end if propagator.inplace - H = _pwc_set_genop!(propagator, n) + if supports_inplace(H) + H = _pwc_set_genop!(propagator, n) + else + H = _pwc_get_genop(propagator, n) + end Newton.newton!( Ψ, H, diff --git a/src/ode_function.jl b/src/ode_function.jl index cf12a2b3..7c7459b0 100644 --- a/src/ode_function.jl +++ b/src/ode_function.jl @@ -1,3 +1,4 @@ +using QuantumPropagators.Interfaces: supports_inplace using TimerOutputs: @timeit_debug, TimerOutput @doc raw""" @@ -65,10 +66,14 @@ end function (f::QuantumODEFunction)(du, u, p, t) @timeit_debug f.timing_data "operator evaluation" begin - evaluate!(f.operator, f.generator, t; vals_dict=p) + if supports_inplace(f.operator) + H = evaluate!(f.operator, f.generator, t; vals_dict=p) + else + H = evaluate(f.generator, t; vals_dict=p) + end end @timeit_debug f.timing_data "matrix-vector product" begin - return mul!(du, f.operator, u, f.c, false) + return mul!(du, H, u, f.c, false) end end diff --git a/src/propagate.jl b/src/propagate.jl index 57d8db8d..56de919d 100644 --- a/src/propagate.jl +++ b/src/propagate.jl @@ -4,7 +4,7 @@ using ProgressMeter using .Storage: init_storage, write_to_storage! import .Storage: map_observables -using .Interfaces: check_state, check_generator, check_tlist +using .Interfaces: check_state, check_generator, check_tlist, supports_inplace # Default "observables" for just storing the propagated state. We want to keep @@ -31,7 +31,7 @@ state = propagate( method, # mandatory keyword argument check=true, backward=false, - inplace=true, + inplace=QuantumPropagators.Interfaces.supports_inplace(state), verbose=false, piecewise=nothing, pwc=nothing, @@ -84,7 +84,9 @@ routine performs the following three steps: * `backward`: If `true`, propagate backward in time * `inplace`: If `true`, propagate using in-place operations. If `false`, avoid in-place operations. Not all propagation methods - support both in-place and not-in-place propagation. + support both in-place and not-in-place propagation. Note that `inplace=true` + requires that [`QuantumPropagators.Interfaces.supports_inplace`](@ref) for + `state` is `true`. * `piecewise`: If given as a boolean, ensure that the internal `propagator` is an instance of [`PiecewisePropagator`](@ref), cf. [`init_prop`](@ref). * `pwc`: If given a a boolean, do a piecewise constant propagation where the @@ -172,11 +174,8 @@ function propagate( show_progress=false, observables=_StoreState(), callback=nothing, - inplace=true, # cf. default of init_prop + inplace=supports_inplace(state), for_expval=true, # undocumented - for_immutable_state=true, # undocumented - for_mutable_operator=inplace, # undocumented - for_mutable_state=inplace, # undocumented kwargs... ) atol = get(kwargs, :atol, 1e-14) # for checks @@ -195,8 +194,6 @@ function propagate( end valid_state = check_state( state; - for_immutable_state, - for_mutable_state, atol, quiet, _message_prefix="On initial `state` in `propagate`: " @@ -204,13 +201,17 @@ function propagate( if !valid_state error("The initial state in `propagate` does not pass `check_state`") end + if inplace + if !supports_inplace(state) + error( + "The propagator is in-place, but the given state does not support in-place operations" + ) + end + end valid_generator = check_generator( generator; state=state, tlist=tlist, - for_immutable_state, - for_mutable_operator, - for_mutable_state, for_expval, atol, quiet, diff --git a/src/propagator.jl b/src/propagator.jl index ec197593..18c4bba6 100644 --- a/src/propagator.jl +++ b/src/propagator.jl @@ -133,7 +133,7 @@ propagator = init_prop( state, generator, tlist; method, # mandatory keyword argument backward=false, - inplace=true, + inplace=QuantumPropagators.Interfaces.supports_inplace(state), piecewise=nothing, pwc=nothing, kwargs... @@ -172,8 +172,8 @@ time grid `tlist` under the time-dependent generator (Hamiltonian/Liouvillian) call to [`prop_step!`](@ref) changes the reference for `propagator.state`, and the propagation will not use any in-place operations. Not all propagation methods may support both in-place and not-in-place propagation. In-place - propagation is generally more efficient but may not be compatible, e.g., with - automatic differentiation. + propagation is generally more efficient for larger Hilbert space dimensions, + but may not be compatible, e.g., with automatic differentiation. * `piecewise`: If given as a boolean, `true` enforces that the resulting propagator is a [`PiecewisePropagator`](@ref), and `false` enforces that it not a [`PiecewisePropagator`](@ref). For the default `piecewise=nothing`, @@ -211,7 +211,7 @@ function init_prop( tlist; method, # mandatory keyword argument backward=false, - inplace=true, + inplace=Interfaces.supports_inplace(state), verbose=false, piecewise=nothing, pwc=nothing, diff --git a/test/test_invalid_interfaces.jl b/test/test_invalid_interfaces.jl index 4ba8addd..3a98f77e 100644 --- a/test/test_invalid_interfaces.jl +++ b/test/test_invalid_interfaces.jl @@ -1,10 +1,11 @@ using Test using Logging: with_logger -using QuantumControlTestUtils: QuantumTestLogger +using IOCapture: IOCapture using QuantumControlTestUtils.RandomObjects: random_dynamic_generator, random_state_vector using StableRNGs: StableRNG -using QuantumPropagators: init_prop +using QuantumPropagators: QuantumPropagators, init_prop using LinearAlgebra +import QuantumPropagators.Interfaces: supports_inplace using QuantumPropagators.Interfaces: check_amplitude, check_control, @@ -21,16 +22,16 @@ using QuantumPropagators.Interfaces: tlist = collect(range(0, 10, length=101)) ampl = InvalidAmplitude() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_amplitude(ampl; tlist) ≡ false + captured = IOCapture.capture() do + check_amplitude(ampl; tlist) end + @test captured.value ≡ false - @test "get_controls(ampl)` must be defined" ∈ test_logger - @test "all controls in `ampl` must pass `check_control`" ∈ test_logger - @test "`substitute(ampl, replacements)` must be defined" ∈ test_logger - @test "`evaluate(ampl, tlist, 1)` must return a Number" ∈ test_logger - @test "`evaluate(ampl, tlist, n; vals_dict)` must be defined" ∈ test_logger + @test contains(captured.output, "get_controls(ampl)` must be defined") + @test contains(captured.output, "all controls in `ampl` must pass `check_control`") + @test contains(captured.output, "`substitute(ampl, replacements)` must be defined") + @test contains(captured.output, "`evaluate(ampl, tlist, 1)` must return a Number") + @test contains(captured.output, "`evaluate(ampl, tlist, n; vals_dict)` must be defined") end @@ -42,23 +43,30 @@ end tlist = collect(range(0, 10, length=101)) control = InvalidControl() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_control(control; tlist) ≡ false + captured = IOCapture.capture() do + check_control(control; tlist) end - - @test "`discretize(control, tlist)` must be defined" ∈ test_logger - @test "`discretize_on_midpoints(control, tlist)` must be defined" ∈ test_logger + @test captured.value ≡ false + @test contains(captured.output, "`discretize(control, tlist)` must be defined") + @test contains( + captured.output, + "`discretize_on_midpoints(control, tlist)` must be defined" + ) tlist = [0.0, 1.0, 2.0, 3.0] control = [0.0, NaN, Inf, 0.0] - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_control(control; tlist) ≡ false + captured = IOCapture.capture() do + check_control(control; tlist) end - @test "all values in `discretize(control, tlist)` must be finite" ∈ test_logger - @test "all values in `discretize_on_midpoints(control, tlist)` must be finite" ∈ - test_logger + @test captured.value ≡ false + @test contains( + captured.output, + "all values in `discretize(control, tlist)` must be finite" + ) + @test contains( + captured.output, + "all values in `discretize_on_midpoints(control, tlist)` must be finite" + ) end @@ -71,16 +79,25 @@ end tlist = collect(range(0, 10, length=101)) operator = InvalidOperator() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_operator(operator; state, tlist) ≡ false + captured = IOCapture.capture() do + check_operator(operator; state, tlist) end - - @test "op must not contain any controls" ∈ test_logger - @test "`op * state` must be defined" ∈ test_logger - @test "The 3-argument `mul!` must apply `op` to the given `state`" ∈ test_logger - @test "The 5-argument `mul!` must apply `op` to the given `state`" ∈ test_logger - @test "`dot(state, op, state)` must return return a number" ∈ test_logger + @test captured.value ≡ false + @test contains( + captured.output, + "The `QuantumPropagators.Interfaces.supports_inplace` method must be defined for `op`" + ) + @test contains(captured.output, "op must not contain any controls") + @test contains(captured.output, "`op * state` must be defined") + @test contains( + captured.output, + "The 3-argument `mul!` must apply `op` to the given `state`" + ) + @test contains( + captured.output, + "The 5-argument `mul!` must apply `op` to the given `state`" + ) + @test contains(captured.output, "`dot(state, op, state)` must return return a number") end @@ -93,61 +110,82 @@ end tlist = collect(range(0, 10, length=101)) generator = InvalidGenerator() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_generator(generator; state, tlist) ≡ false + captured = IOCapture.capture() do + check_generator(generator; state, tlist) end + @test captured.value ≡ false - @test "`get_controls(generator)` must be defined" ∈ test_logger - @test "`evaluate(generator, tlist, n)` must return an operator that passes `check_operator`" ∈ - test_logger - @test "`substitute(generator, replacements)` must be defined" ∈ test_logger + @test contains(captured.output, "`get_controls(generator)` must be defined") + @test contains( + captured.output, + "`evaluate(generator, tlist, n)` must return an operator that passes `check_operator`" + ) + @test contains(captured.output, "`substitute(generator, replacements)` must be defined") end @testset "Invalid state" begin + struct InvalidStaticState end + state = InvalidStaticState() + captured = IOCapture.capture() do + check_state(state) + end + @test captured.value ≡ false + @test contains( + captured.output, + "The `QuantumPropagators.Interfaces.supports_inplace` method must be defined for `state`" + ) + struct InvalidState end + QuantumPropagators.Interfaces.supports_inplace(::InvalidState) = true state = InvalidState() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_state(state; normalized=true) ≡ false + captured = IOCapture.capture() do + check_state(state; normalized=true) end - @test "`similar(state)` must be defined" ∈ test_logger - @test "the inner product of two states must be a complex number" ∈ test_logger - @test "the norm of a state must be defined via the inner product" ∈ test_logger - @test "`state + state` and `state - state` must be defined" ∈ test_logger - @test "copy(state) must be defined" ∈ test_logger - @test "`c * state` for a scalar `c` must be defined" ∈ test_logger - @test "`0.0 * state` must produce a state with norm 0" ∈ test_logger - @test "`norm(state)` must be 1" ∈ test_logger - + @test captured.value ≡ false + @test contains(captured.output, "`similar(state)` must be defined") + @test contains( + captured.output, + "the inner product of two states must be a complex number" + ) + @test contains( + captured.output, + "the norm of a state must be defined via the inner product" + ) + @test contains(captured.output, "`state + state` and `state - state` must be defined") + @test contains(captured.output, "copy(state) must be defined") + @test contains(captured.output, "`c * state` for a scalar `c` must be defined") + @test contains(captured.output, "`0.0 * state` must produce a state with norm 0") + @test contains(captured.output, "`norm(state)` must be 1") struct InvalidState2 end + QuantumPropagators.Interfaces.supports_inplace(::InvalidState2) = true state = InvalidState2() Base.similar(::InvalidState2) = InvalidState2() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_state(state; normalized=true) ≡ false + captured = IOCapture.capture() do + check_state(state; normalized=true) end - @test "`copyto!(other, state)` must be defined" ∈ test_logger + @test captured.value ≡ false + @test contains(captured.output, "`copyto!(other, state)` must be defined") struct InvalidState3 end + QuantumPropagators.Interfaces.supports_inplace(::InvalidState3) = true state = InvalidState3() Base.similar(::InvalidState3) = InvalidState3() Base.copyto!(a::InvalidState3, b::InvalidState3) = a - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_state(state; normalized=true) ≡ false + captured = IOCapture.capture() do + check_state(state; normalized=true) end - @test "`similar(state)` must return a valid state" ∈ test_logger - @test "On `similar(state)`: " ∈ test_logger - + @test captured.value ≡ false + @test contains(captured.output, "`similar(state)` must return a valid state") + @test contains(captured.output, "On `similar(state)`: ") struct InvalidState4 Ψ end + QuantumPropagators.Interfaces.supports_inplace(::InvalidState4) = true Base.similar(state::InvalidState4) = InvalidState4(similar(state.Ψ)) Base.copy(state::InvalidState4) = InvalidState4(copy(state.Ψ)) Base.copyto!(a::InvalidState4, b::InvalidState4) = copyto!(a.Ψ, b.Ψ) @@ -157,20 +195,20 @@ end Base.:-(a::InvalidState4, b::InvalidState4) = InvalidState4(a.Ψ - b.Ψ) Base.:*(α::Number, state::InvalidState4) = InvalidState4(α * state.Ψ) state = InvalidState4([1im, 0]) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_state(state; normalized=true) ≡ false + captured = IOCapture.capture() do + check_state(state; normalized=true) end - @test "`lmul!(c, state)` for a scalar `c` must be defined" ∈ test_logger - @test "`lmul!(0.0, state)` must produce a state with norm 0" ∈ test_logger - @test "`axpy!(c, state, other)` must be defined" ∈ test_logger + @test captured.value ≡ false + @test contains(captured.output, "`lmul!(c, state)` for a scalar `c` must be defined") + @test contains(captured.output, "`lmul!(0.0, state)` must produce a state with norm 0") + @test contains(captured.output, "`axpy!(c, state, other)` must be defined") state = [1, 0, 0, 0] - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_state(state; normalized=true) ≡ false + captured = IOCapture.capture() do + check_state(state; normalized=true) end - @test "`state ⋅ state` must return a Complex number type" ∈ test_logger + @test captured.value ≡ false + @test contains(captured.output, "`state ⋅ state` must return a Complex number type") end @@ -183,12 +221,12 @@ end propagator = InvalidPropagatorEmpty() - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end + @test captured.value ≡ false - @test "does not have the required properties" ∈ test_logger + @test contains(captured.output, "does not have the required properties") N = 10 tlist = collect(range(0, 100, length=1001)) @@ -198,15 +236,13 @@ end propagator = init_prop(Ψ, Ĥ, tlist; method=:invalid_propagator_no_methods) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end - @test "`prop_step!(propagator)` must be defined" ∈ test_logger - @test "Failed to run `prop_step!(propagator)`" ∈ test_logger - @test "`reinit_prop!` must be defined" ∈ test_logger - + @test contains(captured.output, "`prop_step!(propagator)` must be defined") + @test contains(captured.output, "Failed to run `prop_step!(propagator)`") + @test contains(captured.output, "`reinit_prop!` must be defined") propagator = init_prop( Ψ, @@ -216,15 +252,21 @@ end inplace=true, backward=false ) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end - @test "propagator.t ≠ propagator.tlist[begin]" ∈ test_logger - @test "`set_t!(propagator, t)` must set propagator.t" ∈ test_logger - @test "`prop_step!(propagator)` at final t=0.1 must return `nothing`" ∈ test_logger - @test "`propagator.parameters` must be a dict" ∈ test_logger - @test "`reinit_prop!(propagator, state)` must reset `propagator.t`" ∈ test_logger + @test captured.value ≡ false + @test contains(captured.output, "propagator.t ≠ propagator.tlist[begin]") + @test contains(captured.output, "`set_t!(propagator, t)` must set propagator.t") + @test contains( + captured.output, + "`prop_step!(propagator)` at final t=0.1 must return `nothing`" + ) + @test contains(captured.output, "`propagator.parameters` must be a dict") + @test contains( + captured.output, + "`reinit_prop!(propagator, state)` must reset `propagator.t`" + ) propagator = init_prop( Ψ, @@ -234,15 +276,19 @@ end inplace=false, backward=true ) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end - @test "propagator.t ≠ propagator.tlist[end]" ∈ test_logger - @test "For a not-in-place propagator, the state returned by `prop_step!` must be a new object" ∈ - test_logger - @test "`prop_step!` must advance `propagator.t` forward or backward one step on the time grid" ∈ - test_logger + @test captured.value ≡ false + @test contains(captured.output, "propagator.t ≠ propagator.tlist[end]") + @test contains( + captured.output, + "For a not-in-place propagator, the state returned by `prop_step!` must be a new object" + ) + @test contains( + captured.output, + "`prop_step!` must advance `propagator.t` forward or backward one step on the time grid" + ) propagator = init_prop( Ψ, @@ -252,15 +298,19 @@ end inplace=true, backward=false ) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end - @test "For an in-place propagator, the state returned by `prop_step!` must be the `propagator.state` object" ∈ - test_logger - @test "`set_state!(propagator, state)` for an in-place propagator must overwrite `propagator.state` in-place." ∈ - test_logger - @test "`reinit_prop!(propagator, state)` must be idempotent" ∈ test_logger + @test captured.value ≡ false + @test contains( + captured.output, + "For an in-place propagator, the state returned by `prop_step!` must be the `propagator.state` object" + ) + @test contains( + captured.output, + "`set_state!(propagator, state)` for an in-place propagator must overwrite `propagator.state` in-place." + ) + @test contains(captured.output, "`reinit_prop!(propagator, state)` must be idempotent") propagator = init_prop( Ψ, @@ -270,13 +320,18 @@ end inplace=false, backward=true ) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end - @test "For a not-in-place propagator, the state returned by `prop_step!` must be a new object" ∈ - test_logger - @test "`reinit_prop!` must be defined and re-initialize the propagator" ∈ test_logger + @test captured.value ≡ false + @test contains( + captured.output, + "For a not-in-place propagator, the state returned by `prop_step!` must be a new object" + ) + @test contains( + captured.output, + "`reinit_prop!` must be defined and re-initialize the propagator" + ) propagator = init_prop( Ψ, @@ -286,11 +341,13 @@ end inplace=true, backward=false ) - test_logger = QuantumTestLogger() - with_logger(test_logger) do - @test check_propagator(propagator) ≡ false + captured = IOCapture.capture() do + check_propagator(propagator) end - @test "prop_step! must return a valid state until time grid is exhausted" ∈ test_logger - + @test captured.value ≡ false + @test contains( + captured.output, + "prop_step! must return a valid state until time grid is exhausted" + ) end From c2eb3a6c16685f854235cf95d06d04fd84061898 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Fri, 26 Jul 2024 10:39:37 -0400 Subject: [PATCH 3/5] Update CI --- .github/workflows/CI.yml | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d3fc32eb..dac3e5f5 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,10 +26,10 @@ jobs: version: '1' steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - name: "Instantiate test environment" run: | wget https://raw.githubusercontent.com/JuliaQuantumControl/JuliaQuantumControl/master/scripts/installorg.jl @@ -38,17 +38,17 @@ jobs: shell: julia --color=yes --project=test --code-coverage="@" --depwarn="yes" --check-bounds="yes" {0} run: | include(joinpath(pwd(), "test", "runtests.jl")) - - name: "Run uptream QuantumControlBase tests" + - name: "Run upstream QuantumControlBase tests" shell: julia --color=yes --project=test --code-coverage="@" --depwarn="yes" --check-bounds="yes" {0} run: | using Pkg Pkg.test("QuantumControlBase", coverage=true) - - name: "Run uptream Krotov tests" + - name: "Run upstream Krotov tests" shell: julia --color=yes --project=test --code-coverage="@" --depwarn="yes" --check-bounds="yes" {0} run: | using Pkg Pkg.test("Krotov", coverage=true) - - name: "Run uptream GRAPE tests" + - name: "Run upstream GRAPE tests" shell: julia --color=yes --project=test --code-coverage="@" --depwarn="yes" --check-bounds="yes" {0} run: | using Pkg @@ -56,18 +56,19 @@ jobs: - uses: julia-actions/julia-processcoverage@v1 - name: "Summarize coverage" run: julia --project=test -e 'using QuantumControlTestUtils; show_coverage();' - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} docs: name: Documentation runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: '1' - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - run: | # Install Python dependencies set -x @@ -83,7 +84,7 @@ jobs: DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} - name: Zip the HTML documentation run: zip-folder --debug --auto-root --outfile "docs.zip" docs/build - - uses: actions/upload-artifact@v3 + - uses: actions/upload-artifact@v4 name: Upload documentation artifacts with: name: QuantumPropagators @@ -97,7 +98,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: '1' - name: Get typos settings From 3ef4a796b962cd9b678072a3681ef24892ab7393 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Sat, 27 Jul 2024 09:40:19 -0400 Subject: [PATCH 4/5] Document in-place propagation --- docs/make.jl | 20 +++++++------------- docs/src/overview.md | 10 ++++++++++ test/Project.toml | 1 + 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index a00a80fb..ecf38bde 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -29,23 +29,17 @@ links = InterLinks( "https://github.com/KristofferC/TimerOutputs.jl", joinpath(@__DIR__, "src", "inventories", "TimerOutputs.toml"), ), - "QuantumControlBase" => "https://juliaquantumcontrol.github.io/QuantumControlBase.jl/$DEV_OR_STABLE", - "ComponentArrays" => ( - "https://jonniedie.github.io/ComponentArrays.jl/stable/", - "https://jonniedie.github.io/ComponentArrays.jl/stable/objects.inv", - joinpath(@__DIR__, "src", "inventories", "ComponentArrays.toml") - ), - "RecursiveArrayTools" => ( - "https://docs.sciml.ai/RecursiveArrayTools/stable/", - "https://docs.sciml.ai/RecursiveArrayTools/stable/objects.inv", - joinpath(@__DIR__, "src", "inventories", "RecursiveArrayTools.toml") - ), + "QuantumControl" => "https://juliaquantumcontrol.github.io/QuantumControl.jl/$DEV_OR_STABLE", + "StaticArrays" => "https://juliaarrays.github.io/StaticArrays.jl/stable/", + "ComponentArrays" => "https://jonniedie.github.io/ComponentArrays.jl/stable/", + "RecursiveArrayTools" => "https://docs.sciml.ai/RecursiveArrayTools/stable/", "qutip" => "https://qutip.readthedocs.io/en/qutip-5.0.x/", ) externals = ExternalFallbacks( - "QuantumControlBase.Trajectory" => "@extref QuantumControlBase :jl:type:`QuantumControlBase.Trajectory`", - "QuantumControlBase.ControlProblem" => "@extref QuantumControlBase :jl:type:`QuantumControlBase.ControlProblem`", + "Trajectory" => "@extref QuantumControl :jl:type:`QuantumControlBase.Trajectory`", + "QuantumControlBase.Trajectory" => "@extref QuantumControl :jl:type:`QuantumControlBase.Trajectory`", + "QuantumControlBase.ControlProblem" => "@extref QuantumControl :jl:type:`QuantumControlBase.ControlProblem`", ) println("Starting makedocs") diff --git a/docs/src/overview.md b/docs/src/overview.md index bb979a82..7c561410 100644 --- a/docs/src/overview.md +++ b/docs/src/overview.md @@ -175,6 +175,16 @@ print("t = $(round(propagator.t / π; digits=3))π\n") t = 0.5π ``` +## In-place propagation + +Most propagators support both an in-place and a not-in-place mode. These modes can be switched via the `inplace` parameter to [`propagate`](@ref)/[`init_prop`](@ref), which defaults to the value of [`QuantumPropagators.Interfaces.supports_inplace`](@ref) for the given initial `state`. When using in-place operations, propagators should minimize the allocation of memory and modify states using in-place linear algebra operations ([BLAS](@extref Julia `BLAS-functions`)). Otherwise, with `inplace=false`, mutating states or operators is avoided. See the in-place and not-in-place operations described in [`QuantumPropagators.Interfaces.check_state`](@ref) and [`QuantumPropagators.Interfaces.check_operator`](@ref). + +In-place operations can be dramatically more efficient for large Hilbert space dimensions. On the other hand, not-in-place operations can be more efficient for small Hilbert spaces, in particular when a [static vector](@extref StaticArrays `SVector`) can be used to represent the state. Moreover, frameworks for automatic differentiation such as [Zygote](https://fluxml.ai/Zygote.jl/stable/) do not support in-place operations. + +When using custom structs for states, operators, or generators, the struct itself not not need to be mutable (according to [`Base.ismutable`](@extref Julia)) in order to support `inplace=true`. It only must support the in-place operations defined in the [formal interface](@ref QuantumPropagatorsInterfacesAPI) and indicate that support by defining [`QuantumPropagators.Interfaces.supports_inplace`](@ref). +Typically, in-place operations on immutable custom structs involve mutating the mutable properties of that struct. + + ## Backward propagation When [`propagate`](@ref) or [`init_prop`](@ref) are called with `backward=true`, the propagation is initialized to run backward. The initial state is then defined at `propagator.t == tlist[end]` and each [`prop_step!`](@ref) moves to the previous point in `tlist`. The equation of motion is the Schrödinger or Liouville equation with a negative $dt$. For a Hermitian `generator`, doing a forward propagation followed by a backward propagation will recover the initial state. For a non-Hermitian `generator`, this no longer holds. Note that in optimal control methods such as GRAPE or Krotov's method, obtaining gradients involves a "backward propagation with the adjoint generator" (when the generator is non-Hermitian and adjoint/non-adjoint makes a difference). The [`propagate`](@ref) routine with `backward=true` will not automatically take this adjoint of the `generator`; instead, the adjoint generator must be passed explicitly. diff --git a/test/Project.toml b/test/Project.toml index 3f6b6007..c31eec4f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocInventories = "43dc2714-ed3b-44b5-b226-857eda1aa7de" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" From 56896218fd6ada7f0bfc5a69baa97c08e066d2d5 Mon Sep 17 00:00:00 2001 From: Michael Goerz Date: Sat, 27 Jul 2024 10:00:46 -0400 Subject: [PATCH 5/5] Release 0.8.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dab345e8..c7899184 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QuantumPropagators" uuid = "7bf12567-5742-4b91-a078-644e72a65fc1" authors = ["Michael Goerz "] -version = "0.7.6+dev" +version = "0.8.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"