Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Move SDE default algorithm from DifferentialEquations.jl to Stochasti…
…cDiffEq.jl

This PR moves the default SDE solver implementation from DifferentialEquations.jl to StochasticDiffEq.jl, following the pattern established in SciML/DelayDiffEq.jl#326 and SciML/DelayDiffEq.jl#334.

## Changes
- Added `src/default_sde_alg.jl` containing the default algorithm selection logic
- Implemented `__init` and `__solve` dispatches for `SDEProblem` with `Nothing` algorithm
- Added `get_alg_hints` helper function for extracting algorithm hints from kwargs
- Added comprehensive tests in `test/default_solver_test.jl`
- Updated module to include the new default algorithm file

## Default Algorithm Behavior
When no algorithm is specified, the solver now automatically selects:
- SOSRI() as the standard default
- RKMilCommute() for commutative noise
- ImplicitRKMil() for stiff problems or non-identity mass matrices
- RKMil() for Stratonovich interpretation
- LambaEM() / LambaEulerHeun() for non-diagonal noise
- ISSEM() / ImplicitEulerHeun() for stiff non-diagonal problems
- SOSRA() / SKenCarp() for additive noise

## Test Plan
- [x] Added tests verifying default solver dispatch
- [x] Tests verify correct algorithm selection for various problem types
- [x] All tests pass locally

This is part of the ongoing effort to modularize DifferentialEquations.jl by moving default solvers to their respective packages.

🤖 Generated with [Claude Code](https://claude.ai/claude-code)

Co-Authored-By: Claude <[email protected]>
  • Loading branch information
ChrisRackauckas and claude committed Oct 10, 2025
commit 8814ce157e4af1db3b3b5543721725b1ce2fd2d9
1 change: 1 addition & 0 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ include("iterated_integrals.jl")
include("SROCK_utils.jl")
include("composite_algs.jl")
include("weak_utils.jl")
include("default_sde_alg.jl")

export StochasticDiffEqAlgorithm, StochasticDiffEqAdaptiveAlgorithm,
StochasticCompositeAlgorithm
Expand Down
77 changes: 77 additions & 0 deletions src/default_sde_alg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Default algorithm selection for SDEs
# Moved from DifferentialEquations.jl as part of modularization effort

using LinearAlgebra: I

# Helper function to extract alg_hints from keyword arguments
function get_alg_hints(o)
:alg_hints ∈ keys(o) ? alg_hints = o[:alg_hints] : alg_hints = Symbol[:auto]
end

function default_algorithm(
prob::DiffEqBase.AbstractSDEProblem{uType, tType, isinplace, ND};
kwargs...) where {uType, tType, isinplace, ND}
o = Dict{Symbol, Any}(kwargs)
alg = SOSRI() # Standard default

alg_hints = get_alg_hints(o)

if :commutative ∈ alg_hints
alg = RKMilCommute()
end

is_stiff = :stiff ∈ alg_hints
is_stratonovich = :stratonovich ∈ alg_hints
if is_stiff || prob.f.mass_matrix !== I
alg = ImplicitRKMil(autodiff = false)
end

if is_stratonovich
if is_stiff || prob.f.mass_matrix !== I
alg = ImplicitRKMil(autodiff = false,
interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
else
alg = RKMil(interpretation = SciMLBase.AlgorithmInterpretation.Stratonovich)
end
end

if prob.noise_rate_prototype != nothing || prob.noise != nothing
if is_stratonovich
if is_stiff || prob.f.mass_matrix !== I
alg = ImplicitEulerHeun(autodiff = false)
else
alg = LambaEulerHeun()
end
else
if is_stiff || prob.f.mass_matrix !== I
alg = ISSEM(autodiff = false)
else
alg = LambaEM()
end
end
end

if :additive ∈ alg_hints
if is_stiff || prob.f.mass_matrix !== I
alg = SKenCarp(autodiff = false)
else
alg = SOSRA()
end
end

return alg
end

# Dispatch for __init with Nothing algorithm - use default
function DiffEqBase.__init(
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
alg = default_algorithm(prob; kwargs...)
DiffEqBase.__init(prob, alg, args...; kwargs...)
end

# Dispatch for __solve with Nothing algorithm - use default
function DiffEqBase.__solve(
prob::DiffEqBase.AbstractSDEProblem, ::Nothing, args...; kwargs...)
alg = default_algorithm(prob; kwargs...)
DiffEqBase.__solve(prob, alg, args...; kwargs...)
end
58 changes: 58 additions & 0 deletions test/default_solver_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using StochasticDiffEq, Test
import SciMLBase
using Random

# Additive SDE test problem
f_additive(u, p, t) = @. p[2] / sqrt(1 + t) - u / (2 * (1 + t))
σ_additive(u, p, t) = @. p[1] * p[2] / sqrt(1 + t)
p = (0.1, 0.05)
additive_analytic(u0, p, t, W) = @. u0 / sqrt(1 + t) + p[2] * (t + p[1] * W) / sqrt(1 + t)
ff_additive = SDEFunction(f_additive, σ_additive, analytic = additive_analytic)
prob_sde_additive = SDEProblem(ff_additive, σ_additive, 1.0, (0.0, 1.0), p)

Random.seed!(100)

# Test default (no algorithm specified) - should use SOSRI
prob = prob_sde_additive
sol = solve(prob, dt = 1 / 2^(3))
@test sol.alg isa SOSRI

# Test with :additive hint - should use SOSRA
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
@test sol.alg isa SOSRA

# Test with :stratonovich hint - should use RKMil with Stratonovich interpretation
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
@test SciMLBase.alg_interpretation(sol.alg) ==
SciMLBase.AlgorithmInterpretation.Stratonovich
@test sol.alg isa RKMil

# Non-diagonal noise test problem
f = (du, u, p, t) -> du .= 1.01u
g = function (du, u, p, t)
du[1, 1] = 0.3u[1]
du[1, 2] = 0.6u[1]
du[1, 3] = 0.9u[1]
du[1, 4] = 0.12u[2]
du[2, 1] = 1.2u[1]
du[2, 2] = 0.2u[2]
du[2, 3] = 0.3u[2]
du[2, 4] = 1.8u[2]
end
prob = SDEProblem(f, g, ones(2), (0.0, 1.0), noise_rate_prototype = zeros(2, 4))

# Test default with non-diagonal noise - should use LambaEM
sol = solve(prob, dt = 1 / 2^(3))
@test sol.alg isa LambaEM

# Test with :stiff hint - should use ISSEM
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stiff])
@test sol.alg isa ISSEM

# Test with :additive hint - should still use SOSRA (overrides non-diagonal)
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:additive])
@test sol.alg isa SOSRA

# Test with :stratonovich hint - should use LambaEulerHeun
sol = solve(prob, dt = 1 / 2^(3), alg_hints = [:stratonovich])
@test sol.alg isa LambaEulerHeun
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")

@time begin
if GROUP == "All" || GROUP == "Interface1"
@time @safetestset "Default Solver Tests" begin
include("default_solver_test.jl")
end
@time @safetestset "First Rand Tests" begin
include("first_rand_test.jl")
end
Expand Down
Loading