Skip to content

Commit 9411367

Browse files
committed
some minor tweaks towards making n-body example working again
1 parent 0b95d58 commit 9411367

9 files changed

Lines changed: 258 additions & 72 deletions

File tree

SciLean/Analysis/Calculus/HasRevFDeriv.lean

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,36 @@ theorem Norm2.norm2.arg_a0.HasRevFDerivUpdate_simple_rule :
10131013
case adjoint => intro; dsimp; data_synth
10141014
case simp => funext x; simp; funext dr x'; module
10151015

1016+
set_option linter.unusedVariables false in
1017+
@[data_synth]
1018+
theorem norm2.arg_a0.HasRevFDeriv_rule
1019+
(f : X → Y) (f')
1020+
(hf : HasRevFDeriv R f f') (hf' : ∀ x, f x ≠ 0) :
1021+
HasRevFDeriv R (fun x => ‖f x‖₂[R])
1022+
(fun x =>
1023+
let' (y,df) := f' x
1024+
let ynorm := ‖y‖₂[R]
1025+
(ynorm, fun dr =>
1026+
let dy := (dr * ynorm⁻¹) • y
1027+
let dx := df dy
1028+
dx)) := by
1029+
sorry_proof
1030+
1031+
set_option linter.unusedVariables false in
1032+
@[data_synth]
1033+
theorem norm2.arg_a0.HasRevFDerivUpdate_rule
1034+
(f : X → Y) (f')
1035+
(hf : HasRevFDerivUpdate R f f') (hf' : ∀ x, f x ≠ 0) :
1036+
HasRevFDerivUpdate R (fun x => ‖f x‖₂[R])
1037+
(fun x =>
1038+
let' (y,df) := f' x
1039+
let ynorm := ‖y‖₂[R]
1040+
(ynorm, fun dr dx =>
1041+
let dy := (dr * ynorm⁻¹) • y
1042+
let dx := df dy dx
1043+
dx)) := by
1044+
sorry_proof
1045+
10161046
end OverReals
10171047

10181048

SciLean/Analysis/ODE/OdeSolve.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ structure HasUniqueOdeSolution (f : R → X → X) extends HasOdeSolution f : Pr
2626
uniq : ∀ t₀ x₀ x x', IsOdeSolution f t₀ x₀ x → IsOdeSolution f t₀ x₀ x' → x = x'
2727

2828
open Classical in
29+
/-- Solution of ordinary differentiale equation.
30+
31+
Function `x := fun t => odeSolve f t₀ t x₀` satisfies ODE
32+
```
33+
∂ x t = f t (x t)
34+
```
35+
with initial condition `x t₀ = x₀`.
36+
37+
Junk value is returned if `f` does define ODE with an unique solution.
38+
-/
2939
noncomputable
3040
def odeSolve (f : R → X → X) (t₀ t : R) (x₀ : X) : X :=
3141
if h : HasUniqueOdeSolution f -- TODO: can we reduce it to just HasOdeSolution?

SciLean/Numerics/Optimization/Optimjl.lean

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -41,37 +41,3 @@ theorem argmin_eq_limit_optimize
4141
let f' := holdLet <| revFDeriv R f
4242
let r := optimize {f:=f,f':=f',hf:=sorry_proof} (AbstractOptimizer.setOptions X method opts) x₀
4343
r.minimizer := sorry_proof
44-
45-
46-
#check LBFGS Float 1
47-
48-
set_default_scalar Float
49-
50-
open Scalar
51-
52-
approx mySqrt_v1 (y : Float) := argmin (x : Float), ‖x^2-y‖₂²
53-
by
54-
rw[argmin_eq_limit_optimize (method := (default : LBFGS Float 1)) (x₀:=1)]
55-
56-
approx_limit options sorry_proof
57-
58-
conv in (revFDeriv _ _) => autodiff
59-
60-
61-
def mySqrt_v2 (y : Float) :=
62-
let r := optimize (d:={f := fun (x : Float) => ‖x^2-y‖₂²,
63-
f' := _,
64-
hf := by data_synth => enter[3]; lsimp})
65-
(method := (default : LBFGS Float 1))
66-
(x₀ := 1)
67-
r.minimizer
68-
69-
70-
71-
/-- info: 1.414212 -/
72-
#guard_msgs in
73-
#eval mySqrt_v1 2 ({},())
74-
75-
/-- info: 1.414214 -/
76-
#guard_msgs in
77-
#eval mySqrt_v1 2 ({x_abstol := 1e-16, g_abstol := 0},())

SciLean/Tactic/DataSynth/Elab.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ syntax (name:=data_synth_conv) "data_synth" optConfig (discharger)? : conv
3333
| none => fun _ => return none
3434
| some stx => Mathlib.Meta.FunProp.tacticToDischarge ⟨stx.raw[3]⟩
3535

36-
let (r?,_) ← dataSynth g |>.run {config := cfg, discharge := disch} |>.run stateRef
36+
let (r?,_) ← dataSynth g |>.run {config := cfg, discharge := fun e => do disch e} |>.run stateRef
3737
|>.run (← Simp.mkDefaultMethods).toMethodsRef
3838
|>.run (← Simp.mkContext (config := cfg.toConfig) (simpTheorems := #[← getSimpTheorems]))
3939
|>.run {}
@@ -81,7 +81,7 @@ syntax (name:=data_synth_tac) "data_synth" optConfig (discharger)? ("=>" convSeq
8181

8282
let stateRef : IO.Ref DataSynth.State ← IO.mkRef {}
8383

84-
let (r?,_) ← dataSynth g |>.run {config := cfg, discharge := disch} |>.run stateRef
84+
let (r?,_) ← dataSynth g |>.run {config := cfg, discharge := fun e => do disch e} |>.run stateRef
8585
|>.run (← Simp.mkDefaultMethods).toMethodsRef
8686
|>.run (← Simp.mkContext (config := cfg.toConfig) (simpTheorems := #[← getSimpTheorems]))
8787
|>.run {}

SciLean/Tactic/DataSynth/Types.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ structure Config extends DataSynthConfig, Simp.Config
120120
structure Context where
121121
config : Config := {}
122122
normalize : Expr → Simp.SimpM Simp.Result := fun e => return {expr := e}
123-
discharge : Expr → MetaM (Option Expr) := fun _ => return .none
123+
discharge : Expr → SimpM (Option Expr) := fun _ => return .none
124124

125125
structure State where
126126
numSteps := 0

Test/data_synth/get_elem.lean

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import SciLean.Data.ArrayOperations.Operations.GetElem
2+
import SciLean.Data.DataArray.Float
3+
import SciLean.Data.DataArray.VectorType
4+
5+
open SciLean
6+
7+
variable {n} (i : Fin n) (j : Fin 3)
8+
9+
/--
10+
info: HasRevFDeriv Float (fun x => x[i]) fun x =>
11+
(x[i], fun xi =>
12+
let x' := setElem 0 i xi True.intro;
13+
x') : Prop
14+
-/
15+
#guard_msgs in
16+
#check (HasRevFDeriv Float (fun x : Float^[n] => x[i]) _ ) rewrite_by data_synth
17+
18+
19+
/--
20+
info: HasRevFDeriv Float (fun x => x[(i, j)]) fun x =>
21+
(x[(i, j)], fun xi =>
22+
let x' := setElem 0 (i, j) xi True.intro;
23+
x') : Prop
24+
-/
25+
#guard_msgs in
26+
#check (HasRevFDeriv Float (fun x : Float^[n,3] => x[i,j]) _ ) rewrite_by data_synth
27+
28+
29+
/--
30+
info: HasRevFDeriv Float (fun x => x[(i, j)]) fun x =>
31+
(x[(i, j)], fun xi =>
32+
let x' := setElem 0 (i, j) xi True.intro;
33+
x') : Prop
34+
-/
35+
#guard_msgs in
36+
#check (HasRevFDeriv Float (fun x : Float^[3]^[n] => x[i,j]) _ ) rewrite_by data_synth
37+
38+
39+
-- this is broken!!!
40+
-- some serious issue with type classes :(
41+
-- #check (HasRevFDeriv Float (fun x : Float^[3]^[n] => x[i]) _ ) rewrite_by data_synth

examples/NBody.lean

Lines changed: 171 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,180 @@
1-
import SciLean.Mechanics
2-
import SciLean.Operators.ODE
3-
import SciLean.Solver
4-
import SciLean.Tactic.LiftLimit
5-
import SciLean.Tactic.FinishImpl
6-
import SciLean.Data.PowType
7-
import SciLean.Core.Extra
8-
import SciLean.Functions
1+
import SciLean
2+
3+
import ProofWidgets.Data.Svg
4+
import ProofWidgets.Component.HtmlDisplay
5+
6+
open Lean ProofWidgets
7+
98

109
open SciLean
1110

1211
set_option synthInstance.maxSize 2048
1312
set_option synthInstance.maxHeartbeats 500000
1413
set_option maxHeartbeats 500000
1514

16-
def H (n : Nat) (ε : ℝ) (m : Idx n → ℝ) (x p : ((ℝ^(3:Nat))^n)) : ℝ :=
17-
(∑ i, (1/(2*m i)) * ∥p[i]∥²)
18-
+
19-
(∑ i j, (m i*m j) * ∥x[i] - x[j]∥^{(-1:ℝ),ε})
20-
argument p [Fact (ε≠0)] [Fact (n≠0)]
21-
isSmooth, hasAdjDiff, adjDiff
22-
argument x [Fact (ε≠0)] [Fact (n≠0)]
23-
isSmooth, hasAdjDiff,
24-
adjDiff by
25-
simp[H]
26-
simp [sum_into_lambda]
27-
simp [← sum_of_add]
28-
29-
def solver (steps : ℕ) (n : Nat) [Fact (n≠0)] (ε : ℝ) [Fact (ε≠0)] (m : Idx n → ℝ)
30-
: Impl (ode_solve (HamiltonianSystem (H n ε m))) :=
15+
macro (priority:=high) A:term noWs "[" i:term "," ":" "]" : term => `(MatrixType.row $A $i)
16+
macro (priority:=high) A:term noWs "[" ":" "," j:term "]" : term => `(MatrixType.col $A $j)
17+
18+
axiom unsafeNonzero {α} [Zero α] (a : α) : a ≠ 0
19+
20+
macro "unsafeAD" : tactic =>
21+
`(tactic| (intros; simp only [not_false_eq_true, ne_eq, unsafeNonzero]))
22+
23+
24+
open Lean Meta
25+
instance : MonadLift Tactic.DataSynth.DataSynthM SimpM where
26+
monadLift e := do
27+
let disch? := (← Simp.getMethods).discharge?
28+
-- discharge? : Expr → SimpM (Option Expr) := fun _ => return none
29+
let r := e { discharge := disch? } (← ST.mkRef {}) (← ST.mkRef {})
30+
r
31+
32+
33+
theorem revFDeriv_from_hasRevFDeriv {K} [RCLike K]
34+
{X} [NormedAddCommGroup X] [AdjointSpace K X]
35+
{Y} [NormedAddCommGroup Y] [AdjointSpace K Y]
36+
{f : X → Y} {f'} (hf : HasRevFDeriv K f f') :
37+
revFDeriv K f = f' := sorry_proof
38+
39+
40+
open Lean Meta in
41+
/-- Compute `revFDeriv R f` with calling data_synth on `HasRevFDeriv R f ?f'`. -/
42+
simproc_decl revFDeriv_simproc (revFDeriv _ _) := fun e => do
43+
44+
-- get field and function to differentiate
45+
let K := e.getArg! 0
46+
let f := e.appArg!
47+
48+
-- craft `HasRevFDeriv K f ?f'`
49+
let goal ← mkAppM ``HasRevFDeriv #[K,f]
50+
let (xs,_,_) ← forallMetaTelescope (← inferType goal)
51+
let f' := xs[0]!
52+
let goal := goal.app f'
53+
54+
-- extract data_synth goal
55+
let .some goal ← Tactic.DataSynth.isDataSynthGoal? goal
56+
| throwError "something went really wrong"
57+
58+
-- run data_synth
59+
let .some r ← Tactic.DataSynth.dataSynth goal
60+
| return .continue
61+
62+
let f'' := r.xs[0]!
63+
let prf ← mkAppM ``revFDeriv_from_hasRevFDeriv #[r.proof]
64+
-- IO.println (← ppExpr (← inferType prf))
65+
66+
-- let eq ← mkEq e f''
67+
-- let prf ← mkSorry eq false
68+
69+
return .visit { expr := f'', proof? := prf }
70+
71+
72+
set_default_scalar Float
73+
open Scalar
74+
example : (<∂ x : Float, x) = fun x => (x, fun dx => dx) := by simp[revFDeriv_simproc]
75+
example : (<∂ x : Float, x*x) = fun x => (x*x, fun dx => x*dx+x*dx) := by simp[revFDeriv_simproc]
76+
example : (<∂ x : Float, exp x) = fun x => (exp x, fun dx => dx*exp x) := by simp[revFDeriv_simproc]
77+
78+
def H {n} (m : Fin n → Float) (x p : Float^[n,2]) : Float :=
79+
(∑ i, (1/(2*m i)) * ‖p[i,:]‖₂²)
80+
-
81+
(∑ (i : Fin n) (j : Fin i.1),
82+
let j := j.castLT (by omega)
83+
(m i*m j) * ‖x[i,:] - x[j,:]‖₂⁻¹)
84+
85+
86+
87+
#check (<∂ x':=(⊞[1.0,2.0,3.0]:Float3), ‖x'‖₂⁻¹) rewrite_by
88+
unfold fgradient
89+
lsimp (disch:=unsafeAD) only [simp_core,revFDeriv_simproc]
90+
91+
variable (x p : Float^[n,3]) (ε : Float)
92+
93+
-- set_option trace.Meta.Tactic.data_synth true in
94+
-- #check (<∂ x':=x, ∑ i j, ‖x'[i,:] - x'[j,:]‖₂⁻¹) rewrite_by
95+
-- unfold fgradient
96+
-- lsimp (disch:=unsafeAD) only [simp_core, revFDeriv_simproc]
97+
98+
-- #check (<∂ x':=x, H (fun _ => 1) x' p) rewrite_by
99+
-- unfold fgradient H
100+
-- lsimp (disch:=unsafeAD) only [simp_core, revFDeriv_simproc]
101+
102+
-- #check (<∂ p':=p, H (fun _ => 1) x p') rewrite_by
103+
-- unfold fgradient H
104+
-- lsimp (disch:=unsafeAD) only [simp_core, revFDeriv_simproc]
105+
106+
107+
-- #check odeSolve.arg_x₀.revFDeriv_rule
108+
109+
110+
approx solver (m : Fin n → Float)
111+
:= odeSolve (fun (t : Float) (x,p) => ( ∇ (p':=p), H m x p',
112+
-∇ (x':=x), H m x' p))
31113
by
32-
-- Unfold Hamiltonian definition and compute gradients
33-
simp[HamiltonianSystem]
34-
114+
-- Unfold Hamiltonian and compute gradients
115+
unfold H fgradient
116+
lsimp (disch:=unsafeAD) only [simp_core,revFDeriv_simproc]
117+
35118
-- Apply RK4 method
36-
rw [ode_solve_fixed_dt runge_kutta4_step]
37-
lift_limit steps "Number of ODE solver steps."; admit; simp
38-
39-
finish_impl
119+
simp_rw (config:={zeta:=false}) [odeSolve_fixed_dt rungeKutta4 sorry_proof]
120+
121+
-- todo: make approx_limit ignore leading let bindings
122+
approx_limit n sorry_proof
123+
124+
125+
126+
#eval! solver (fun _ : Fin 2 => 1) (10,()) 0 0.7
127+
(⊞[-1.0,0;1.0,0],⊞[0.0,-1.0;0,1.0])
128+
129+
130+
def generateData (m : ℕ) : Float^[2,2]^[m] := Id.run do
131+
132+
let mut data : Float^[2,2]^[m] := 0
133+
let Δt := 0.01
134+
135+
-- initial state
136+
let mut x := ⊞[-1.0, 0.0;
137+
1.0, 0.0]
138+
139+
let mut p := ⊞[ 0.0,-0.3;
140+
0.0, 0.3]
141+
142+
for h : i in [0:m] do
143+
let i : Fin m := ⟨i, sorry_proof⟩
144+
data[i] := x
145+
(x,p) := solver 1 (1,()) 0 Δt (x,p)
146+
147+
return data
148+
149+
150+
open ProofWidgets Svg
151+
152+
private def frame : Frame where
153+
xmin := -2
154+
ymin := -2
155+
xSize := 4
156+
width := 400
157+
height := 400
158+
159+
open IndexType
160+
instance {I} [IndexType I] : ToJson (Float^[I]) where
161+
toJson x := toJson (Array.ofFn (fun i => x[fromFin i]))
162+
163+
instance {I} [IndexType I] : FromJson (Float^[I]) where
164+
fromJson? j :=
165+
match fromJson? (α:=Array Float) j with
166+
| Except.error e => Except.error e
167+
| Except.ok data =>
168+
Except.ok (⊞ (i : I) => data[toFin i]!)
169+
170+
171+
instance (f : Frame) : Coe (Float^[2]) (Point f) := ⟨fun x => .abs x[0] x[1]⟩
172+
173+
private def svg : Svg frame :=
174+
let data := generateData 1000
175+
176+
{ elements :=
177+
#[polyline (Array.ofFn (fun i => data[i][0,:])) |>.setStroke (0.8,0.8,0.2) (.px 1) ,
178+
polyline (Array.ofFn (fun i => data[i][1,:])) |>.setStroke (0.2,0.8,0.8) (.px 1) ] }
179+
180+
#html svg.toHtml

examples/WaveEquation.lean

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import SciLean
2-
import SciLean
32

43
open SciLean
54

0 commit comments

Comments
 (0)