|
1 | | -import SciLean.Mechanics |
2 | | -import SciLean.Operators.ODE |
3 | | -import SciLean.Solver |
4 | | -import SciLean.Tactic.LiftLimit |
5 | | -import SciLean.Tactic.FinishImpl |
6 | | -import SciLean.Data.PowType |
7 | | -import SciLean.Core.Extra |
8 | | -import SciLean.Functions |
| 1 | +import SciLean |
| 2 | + |
| 3 | +import ProofWidgets.Data.Svg |
| 4 | +import ProofWidgets.Component.HtmlDisplay |
| 5 | + |
| 6 | +open Lean ProofWidgets |
| 7 | + |
9 | 8 |
|
10 | 9 | open SciLean |
11 | 10 |
|
12 | 11 | set_option synthInstance.maxSize 2048 |
13 | 12 | set_option synthInstance.maxHeartbeats 500000 |
14 | 13 | set_option maxHeartbeats 500000 |
15 | 14 |
|
16 | | -def H (n : Nat) (ε : ℝ) (m : Idx n → ℝ) (x p : ((ℝ^(3:Nat))^n)) : ℝ := |
17 | | - (∑ i, (1/(2*m i)) * ∥p[i]∥²) |
18 | | - + |
19 | | - (∑ i j, (m i*m j) * ∥x[i] - x[j]∥^{(-1:ℝ),ε}) |
20 | | -argument p [Fact (ε≠0)] [Fact (n≠0)] |
21 | | - isSmooth, hasAdjDiff, adjDiff |
22 | | -argument x [Fact (ε≠0)] [Fact (n≠0)] |
23 | | - isSmooth, hasAdjDiff, |
24 | | - adjDiff by |
25 | | - simp[H] |
26 | | - simp [sum_into_lambda] |
27 | | - simp [← sum_of_add] |
28 | | - |
29 | | -def solver (steps : ℕ) (n : Nat) [Fact (n≠0)] (ε : ℝ) [Fact (ε≠0)] (m : Idx n → ℝ) |
30 | | - : Impl (ode_solve (HamiltonianSystem (H n ε m))) := |
| 15 | +macro (priority:=high) A:term noWs "[" i:term "," ":" "]" : term => `(MatrixType.row $A $i) |
| 16 | +macro (priority:=high) A:term noWs "[" ":" "," j:term "]" : term => `(MatrixType.col $A $j) |
| 17 | + |
| 18 | +axiom unsafeNonzero {α} [Zero α] (a : α) : a ≠ 0 |
| 19 | + |
| 20 | +macro "unsafeAD" : tactic => |
| 21 | + `(tactic| (intros; simp only [not_false_eq_true, ne_eq, unsafeNonzero])) |
| 22 | + |
| 23 | + |
| 24 | +open Lean Meta |
| 25 | +instance : MonadLift Tactic.DataSynth.DataSynthM SimpM where |
| 26 | + monadLift e := do |
| 27 | + let disch? := (← Simp.getMethods).discharge? |
| 28 | + -- discharge? : Expr → SimpM (Option Expr) := fun _ => return none |
| 29 | + let r := e { discharge := disch? } (← ST.mkRef {}) (← ST.mkRef {}) |
| 30 | + r |
| 31 | + |
| 32 | + |
| 33 | +theorem revFDeriv_from_hasRevFDeriv {K} [RCLike K] |
| 34 | + {X} [NormedAddCommGroup X] [AdjointSpace K X] |
| 35 | + {Y} [NormedAddCommGroup Y] [AdjointSpace K Y] |
| 36 | + {f : X → Y} {f'} (hf : HasRevFDeriv K f f') : |
| 37 | + revFDeriv K f = f' := sorry_proof |
| 38 | + |
| 39 | + |
| 40 | +open Lean Meta in |
| 41 | +/-- Compute `revFDeriv R f` with calling data_synth on `HasRevFDeriv R f ?f'`. -/ |
| 42 | +simproc_decl revFDeriv_simproc (revFDeriv _ _) := fun e => do |
| 43 | + |
| 44 | + -- get field and function to differentiate |
| 45 | + let K := e.getArg! 0 |
| 46 | + let f := e.appArg! |
| 47 | + |
| 48 | + -- craft `HasRevFDeriv K f ?f'` |
| 49 | + let goal ← mkAppM ``HasRevFDeriv #[K,f] |
| 50 | + let (xs,_,_) ← forallMetaTelescope (← inferType goal) |
| 51 | + let f' := xs[0]! |
| 52 | + let goal := goal.app f' |
| 53 | + |
| 54 | + -- extract data_synth goal |
| 55 | + let .some goal ← Tactic.DataSynth.isDataSynthGoal? goal |
| 56 | + | throwError "something went really wrong" |
| 57 | + |
| 58 | + -- run data_synth |
| 59 | + let .some r ← Tactic.DataSynth.dataSynth goal |
| 60 | + | return .continue |
| 61 | + |
| 62 | + let f'' := r.xs[0]! |
| 63 | + let prf ← mkAppM ``revFDeriv_from_hasRevFDeriv #[r.proof] |
| 64 | + -- IO.println (← ppExpr (← inferType prf)) |
| 65 | + |
| 66 | + -- let eq ← mkEq e f'' |
| 67 | + -- let prf ← mkSorry eq false |
| 68 | + |
| 69 | + return .visit { expr := f'', proof? := prf } |
| 70 | + |
| 71 | + |
| 72 | +set_default_scalar Float |
| 73 | +open Scalar |
| 74 | +example : (<∂ x : Float, x) = fun x => (x, fun dx => dx) := by simp[revFDeriv_simproc] |
| 75 | +example : (<∂ x : Float, x*x) = fun x => (x*x, fun dx => x*dx+x*dx) := by simp[revFDeriv_simproc] |
| 76 | +example : (<∂ x : Float, exp x) = fun x => (exp x, fun dx => dx*exp x) := by simp[revFDeriv_simproc] |
| 77 | + |
| 78 | +def H {n} (m : Fin n → Float) (x p : Float^[n,2]) : Float := |
| 79 | + (∑ i, (1/(2*m i)) * ‖p[i,:]‖₂²) |
| 80 | + - |
| 81 | + (∑ (i : Fin n) (j : Fin i.1), |
| 82 | + let j := j.castLT (by omega) |
| 83 | + (m i*m j) * ‖x[i,:] - x[j,:]‖₂⁻¹) |
| 84 | + |
| 85 | + |
| 86 | + |
| 87 | +#check (<∂ x':=(⊞[1.0,2.0,3.0]:Float3), ‖x'‖₂⁻¹) rewrite_by |
| 88 | + unfold fgradient |
| 89 | + lsimp (disch:=unsafeAD) only [simp_core,revFDeriv_simproc] |
| 90 | + |
| 91 | +variable (x p : Float^[n,3]) (ε : Float) |
| 92 | + |
| 93 | +-- set_option trace.Meta.Tactic.data_synth true in |
| 94 | +-- #check (<∂ x':=x, ∑ i j, ‖x'[i,:] - x'[j,:]‖₂⁻¹) rewrite_by |
| 95 | +-- unfold fgradient |
| 96 | +-- lsimp (disch:=unsafeAD) only [simp_core, revFDeriv_simproc] |
| 97 | + |
| 98 | +-- #check (<∂ x':=x, H (fun _ => 1) x' p) rewrite_by |
| 99 | +-- unfold fgradient H |
| 100 | +-- lsimp (disch:=unsafeAD) only [simp_core, revFDeriv_simproc] |
| 101 | + |
| 102 | +-- #check (<∂ p':=p, H (fun _ => 1) x p') rewrite_by |
| 103 | +-- unfold fgradient H |
| 104 | +-- lsimp (disch:=unsafeAD) only [simp_core, revFDeriv_simproc] |
| 105 | + |
| 106 | + |
| 107 | +-- #check odeSolve.arg_x₀.revFDeriv_rule |
| 108 | + |
| 109 | + |
| 110 | +approx solver (m : Fin n → Float) |
| 111 | + := odeSolve (fun (t : Float) (x,p) => ( ∇ (p':=p), H m x p', |
| 112 | + -∇ (x':=x), H m x' p)) |
31 | 113 | by |
32 | | - -- Unfold Hamiltonian definition and compute gradients |
33 | | - simp[HamiltonianSystem] |
34 | | - |
| 114 | + -- Unfold Hamiltonian and compute gradients |
| 115 | + unfold H fgradient |
| 116 | + lsimp (disch:=unsafeAD) only [simp_core,revFDeriv_simproc] |
| 117 | + |
35 | 118 | -- Apply RK4 method |
36 | | - rw [ode_solve_fixed_dt runge_kutta4_step] |
37 | | - lift_limit steps "Number of ODE solver steps."; admit; simp |
38 | | - |
39 | | - finish_impl |
| 119 | + simp_rw (config:={zeta:=false}) [odeSolve_fixed_dt rungeKutta4 sorry_proof] |
| 120 | + |
| 121 | + -- todo: make approx_limit ignore leading let bindings |
| 122 | + approx_limit n sorry_proof |
| 123 | + |
| 124 | + |
| 125 | + |
| 126 | +#eval! solver (fun _ : Fin 2 => 1) (10,()) 0 0.7 |
| 127 | + (⊞[-1.0,0;1.0,0],⊞[0.0,-1.0;0,1.0]) |
| 128 | + |
| 129 | + |
| 130 | +def generateData (m : ℕ) : Float^[2,2]^[m] := Id.run do |
| 131 | + |
| 132 | + let mut data : Float^[2,2]^[m] := 0 |
| 133 | + let Δt := 0.01 |
| 134 | + |
| 135 | + -- initial state |
| 136 | + let mut x := ⊞[-1.0, 0.0; |
| 137 | + 1.0, 0.0] |
| 138 | + |
| 139 | + let mut p := ⊞[ 0.0,-0.3; |
| 140 | + 0.0, 0.3] |
| 141 | + |
| 142 | + for h : i in [0:m] do |
| 143 | + let i : Fin m := ⟨i, sorry_proof⟩ |
| 144 | + data[i] := x |
| 145 | + (x,p) := solver 1 (1,()) 0 Δt (x,p) |
| 146 | + |
| 147 | + return data |
| 148 | + |
| 149 | + |
| 150 | +open ProofWidgets Svg |
| 151 | + |
| 152 | +private def frame : Frame where |
| 153 | + xmin := -2 |
| 154 | + ymin := -2 |
| 155 | + xSize := 4 |
| 156 | + width := 400 |
| 157 | + height := 400 |
| 158 | + |
| 159 | +open IndexType |
| 160 | +instance {I} [IndexType I] : ToJson (Float^[I]) where |
| 161 | + toJson x := toJson (Array.ofFn (fun i => x[fromFin i])) |
| 162 | + |
| 163 | +instance {I} [IndexType I] : FromJson (Float^[I]) where |
| 164 | + fromJson? j := |
| 165 | + match fromJson? (α:=Array Float) j with |
| 166 | + | Except.error e => Except.error e |
| 167 | + | Except.ok data => |
| 168 | + Except.ok (⊞ (i : I) => data[toFin i]!) |
| 169 | + |
| 170 | + |
| 171 | +instance (f : Frame) : Coe (Float^[2]) (Point f) := ⟨fun x => .abs x[0] x[1]⟩ |
| 172 | + |
| 173 | +private def svg : Svg frame := |
| 174 | + let data := generateData 1000 |
| 175 | + |
| 176 | + { elements := |
| 177 | + #[polyline (Array.ofFn (fun i => data[i][0,:])) |>.setStroke (0.8,0.8,0.2) (.px 1) , |
| 178 | + polyline (Array.ofFn (fun i => data[i][1,:])) |>.setStroke (0.2,0.8,0.8) (.px 1) ] } |
| 179 | + |
| 180 | +#html svg.toHtml |
0 commit comments