diff --git a/Project.toml b/Project.toml index 71f7f70b7..23fa54558 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "EnsembleKalmanProcesses" uuid = "aa8a2aa5-91d8-4396-bcef-d4f2ec43552d" authors = ["CLIMA contributors "] -version = "2.4.1" +version = "2.4.2" [deps] Convex = "f65535da-76fb-5f13-bab9-19810c17039a" @@ -23,6 +23,12 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" +[weakdeps] +Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" + +[extensions] +EnsembleKalmanProcessesMakieExt = "Makie" + [compat] Convex = "0.15, 0.16" Distributions = "0.24.14, 0.25" @@ -46,9 +52,10 @@ TSVD = "0.4.4" julia = "1.5" [extras] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["StableRNGs", "Test", "Plots"] +test = ["CairoMakie", "StableRNGs", "Test", "Plots"] diff --git a/docs/Project.toml b/docs/Project.toml index dc84d5869..a351083e7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,5 +1,6 @@ [deps] CLIMAParameters = "6eacf6c3-8458-43b9-ae03-caf5306d3d53" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CloudMicrophysics = "6a9e3e04-43cd-43ba-94b9-e8782df3c71b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -9,4 +10,4 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" [compat] -CloudMicrophysics = "0.5" \ No newline at end of file +CloudMicrophysics = "0.5" diff --git a/docs/make.jl b/docs/make.jl index 437524527..74c40b54c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,6 +3,7 @@ using EnsembleKalmanProcesses, Documenter, Plots, # so that Literate.jl does not capture precompilation output + CairoMakie, Literate prepend!(LOAD_PATH, [joinpath(@__DIR__, "..")]) @@ 
-56,6 +57,7 @@ api = [ "SparseInversion" => "API/SparseInversion.md", "TOML Interface" => "API/TOMLInterface.md", "Localizers" => "API/Localizers.md", + "Visualize" => "API/Visualize.md", ] examples = [ @@ -89,6 +91,7 @@ pages = [ "Inflation" => "inflation.md", "Parallelism and HPC" => "parallel_hpc.md", "Internal data representation" => "internal_data_representation.md", + "Visualization" => "visualization.md", "Troubleshooting" => "troubleshooting.md", "API" => api, "Contributing" => "contributing.md", @@ -104,7 +107,7 @@ makedocs( authors = "CliMA Contributors", format = format, pages = pages, - modules = [EnsembleKalmanProcesses], + modules = [EnsembleKalmanProcesses, Base.get_extension(EnsembleKalmanProcesses, :EnsembleKalmanProcessesMakieExt)], doctest = true, clean = true, checkdocs = :none, diff --git a/docs/src/API/EnsembleKalmanProcess.md b/docs/src/API/EnsembleKalmanProcess.md index 195e318e3..9f6c8713b 100644 --- a/docs/src/API/EnsembleKalmanProcess.md +++ b/docs/src/API/EnsembleKalmanProcess.md @@ -4,14 +4,24 @@ CurrentModule = EnsembleKalmanProcesses ``` +## Primary objects and functions ```@docs EnsembleKalmanProcess -FailureHandler -SampleSuccGauss -IgnoreFailures +construct_initial_ensemble +update_ensemble! +``` +## Getter functions + +```@docs get_u get_g get_ϕ +get_obs(ekp::EnsembleKalmanProcess) +get_obs(ekp::EnsembleKalmanProcess, iteration) +get_obs_noise_cov(ekp::EnsembleKalmanProcess) +get_obs_noise_cov(ekp::EnsembleKalmanProcess, iteration) +get_obs_noise_cov_inv(ekp::EnsembleKalmanProcess) +get_obs_noise_cov_inv(ekp::EnsembleKalmanProcess, iteration) get_u_final get_g_final get_ϕ_final @@ -23,8 +33,6 @@ get_u_mean_final get_u_cov_final get_g_mean_final get_ϕ_mean_final -compute_error! -get_error get_N_iterations get_N_ens get_accelerator @@ -37,14 +45,25 @@ get_localizer get_localizer_type get_nan_tolerance get_nan_row_values -construct_initial_ensemble -update_ensemble! 
-sample_empirical_gaussian -split_indices_by_success -impute_over_nans list_update_groups_over_minibatch ``` +## [Error metrics](@id errors_api) +```@docs +compute_average_rmse +compute_loss_at_mean +compute_average_unweighted_rmse +compute_unweighted_loss_at_mean +compute_bayes_loss_at_mean +compute_crps +compute_error! +get_error_metrics +get_error +lmul_obs_noise_cov +lmul_obs_noise_cov_inv +lmul_obs_noise_cov! +lmul_obs_noise_cov_inv! +``` ## [Learning Rate Schedulers](@id scheduler_api) ```@docs @@ -54,7 +73,16 @@ EKSStableScheduler DataMisfitController calculate_timestep! ``` +## Failure and NaN handling +```@docs +FailureHandler +SampleSuccGauss +IgnoreFailures +sample_empirical_gaussian +split_indices_by_success +impute_over_nans +``` ## [Process-specific](@id process_api) ```@docs diff --git a/docs/src/API/Observations.md b/docs/src/API/Observations.md index 01f5f185d..6ebc79c79 100644 --- a/docs/src/API/Observations.md +++ b/docs/src/API/Observations.md @@ -45,13 +45,19 @@ get_minibatches(m::RFSM) where {RFSM <: RandomFixedSizeMinibatcher} ```@docs ObservationSeries get_observations(os::OS) where {OS <: ObservationSeries} +get_length_epoch(os::OS) where {OS <: ObservationSeries} get_minibatches(os::OS) where {OS <: ObservationSeries} +get_minibatch_index get_current_minibatch_index(os::OS) where {OS <: ObservationSeries} get_minibatcher(os::OS) where {OS <: ObservationSeries} get_metadata update_minibatch!(os::OS) where {OS <: ObservationSeries} +get_minibatch get_current_minibatch(os::OS) where {OS <: ObservationSeries} get_obs(os::OS) where {OS <: ObservationSeries} +get_obs(os::OS, iteration::IorN) where {OS <: ObservationSeries, IorN <: Union{Int, Nothing}} get_obs_noise_cov(os::OS) where {OS <: ObservationSeries} +get_obs_noise_cov(os::OS, iteration::IorN) where {OS <: ObservationSeries, IorN <: Union{Int, Nothing}} get_obs_noise_cov_inv(os::OS) where {OS <: ObservationSeries} +get_obs_noise_cov_inv(os::OS, iteration::IorN) where {OS <: 
ObservationSeries, IorN <: Union{Int, Nothing}}
```
\ No newline at end of file
diff --git a/docs/src/API/Visualize.md b/docs/src/API/Visualize.md
new file mode 100644
index 000000000..26c95aeb8
--- /dev/null
+++ b/docs/src/API/Visualize.md
@@ -0,0 +1,25 @@
+# Visualize
+
+!!! note "Add a Makie-backend package to your Project.toml"
+    Import one of the Makie backends (GLMakie, CairoMakie, WGLMakie, RPRMakie,
+    etc.) to enable these functions!
+
+```@meta
+CurrentModule = EnsembleKalmanProcesses
+```
+
+```@docs
+Visualize.plot_parameter_distribution
+Visualize.plot_error_over_iters
+Visualize.plot_error_over_iters!
+Visualize.plot_error_over_time
+Visualize.plot_error_over_time!
+Visualize.plot_ϕ_over_iters
+Visualize.plot_ϕ_over_iters!
+Visualize.plot_ϕ_over_time
+Visualize.plot_ϕ_over_time!
+Visualize.plot_ϕ_mean_over_iters!
+Visualize.plot_ϕ_mean_over_iters
+Visualize.plot_ϕ_mean_over_time!
+Visualize.plot_ϕ_mean_over_time
+```
diff --git a/docs/src/visualization.md b/docs/src/visualization.md
new file mode 100644
index 000000000..723464d09
--- /dev/null
+++ b/docs/src/visualization.md
@@ -0,0 +1,195 @@
+# Visualization
+
+We provide some simple plotting features to help users diagnose convergence in
+the algorithm, through:
+
+- Visualizing priors
+- Visualizing slices of parameter ensembles over iterations (or algorithm time) over input space
+- Visualizing the loss function or other computed error metrics from [`compute_error!`](@ref)
+
+*The following documentation provides the overview of what is currently implemented for Plots or Makie backends; these will be expanded in due course.*
+
+## Plots.jl
+
+!!! note "Add Plots.jl to your Project.toml"
+    To enable plotting by Plots.jl, use `using Plots`.
+
+Plotting using Plots.jl supports only plotting distributions. See the example
+below.
+ +```@example +using EnsembleKalmanProcesses.ParameterDistributions +prior_u1 = constrained_gaussian("positive_with_mean_2", 2, 1, 0, Inf) +prior_u2 = constrained_gaussian("four_with_spread_5", 0, 5, -Inf, Inf, repeats=4) +prior = combine_distributions([prior_u1, prior_u2]) + +using Plots +p = plot(prior) +``` + +## Makie.jl + +!!! note "Add a Makie-backend package to your Project.toml" + Import one of the Makie backends (GLMakie, CairoMakie, WGLMakie, RPRMakie, + etc.) to enable these functions! + +Plotting functionality is provided by Makie.jl through a package extension. See +the [documentation](@ref Visualize) for a list of all the available plotting +functions. The plotting functions have the same signature that one would expect +from Makie plotting functions. See the +[plot methods section](https://docs.makie.org/dev/explanations/plot_method_signatures) +in Makie documentation for more information. + +```@setup makie_plots +# Fix random seed, so plots don't change +import Random +rng_seed = 1234 +rng = Random.MersenneTwister(rng_seed) + +using LinearAlgebra +using EnsembleKalmanProcesses +using EnsembleKalmanProcesses.ParameterDistributions + +G(u) = + [1.0 / abs(u[1]), sum(u[2:3]), u[3], u[1]^2 - u[2] - u[3], u[1], 5.0] .+ 0.1 * randn(6) +true_u = [3, 1, 2] +y = G(true_u) +Γ = (0.1)^2 * I + +prior_u1 = constrained_gaussian("positive_with_mean_1", 2, 1, 0, Inf) +prior_u2 = constrained_gaussian("two_with_spread_2", 0, 5, -Inf, Inf, repeats = 2) +prior = combine_distributions([prior_u1, prior_u2]) + +N_ensemble = 30 +initial_ensemble = construct_initial_ensemble(prior, N_ensemble) +ekp = EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(), verbose = true) + +N_iterations = 10 +for i = 1:N_iterations + params_i = get_ϕ_final(prior, ekp) + + G_matrix = hcat( + [G(params_i[:, i]) for i = 1:N_ensemble]..., # Parallelize here! 
+    )
+    update_ensemble!(ekp, G_matrix)
+end
+```
+
+### Plot priors
+
+The function `plot_parameter_distribution` will plot the marginal histograms for
+all dimensions of the parameter distribution.
+
+```@example makie_plots
+using CairoMakie # load a Makie backend
+import EnsembleKalmanProcesses.Visualize as viz
+
+fig_priors = CairoMakie.Figure(size = (500, 400))
+viz.plot_parameter_distribution(fig_priors[1, 1], prior)
+
+fig_priors
+```
+
+### Plot errors
+
+The functions `plot_error_over_iters` and `plot_error_over_time` and the
+mutating versions can be used to plot the errors over the number of iterations
+or time respectively. All keyword arguments supported by
+[`Makie.lines`](https://docs.makie.org/dev/reference/plots/lines) are supported
+by these functions too.
+
+Any of the stored error metrics computed by EnsembleKalmanProcess can be
+plotted by the `error_metric="metric-name"` keyword argument. See
+[`compute_error!`](@ref) for a list of the computed error metrics and their
+names.
+
+```@example makie_plots
+# Assume that ekp is the EnsembleKalmanProcess object
+using CairoMakie # load a Makie backend
+import EnsembleKalmanProcesses.Visualize as viz
+
+fig_errors = CairoMakie.Figure(size = (300 * 2, 300 * 1))
+viz.plot_error_over_iters(fig_errors[1, 1], ekp, color = :tomato)
+# Error plotting functions support plotting different errors through the
+# error_metric keyword argument
+ax1 = CairoMakie.Axis(fig_errors[1, 2], title = "Average RMSE over iterations", yscale = Makie.pseudolog10)
+viz.plot_error_over_time!(
+    ax1,
+    ekp,
+    linestyle = :dashdotdot,
+    error_metric = "avg_rmse",
+)
+fig_errors
+```
+
+### Plot constrained parameters
+
+#### Plot by a slice of the parameter
+
+The functions `plot_ϕ_over_iters` and `plot_ϕ_over_time` and the mutating
+versions can be used to plot a scatter plot of an individual parameter (dimension) over
+the number of iterations or time respectively.
All keyword arguments supported
+by [`Makie.Scatter`](https://docs.makie.org/dev/reference/plots/scatter) are
+supported by these functions too.
+
+Furthermore, there are also `plot_ϕ_mean_over_iters` and `plot_ϕ_mean_over_time`
+and the mutating versions, which can be used to plot the mean parameter (dimension)
+over the number of iterations or time respectively. All keyword arguments
+supported by [`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines)
+are supported by these functions too. Furthermore, there is the keyword argument
+`plot_std = true` which can be used to plot the standard deviation as well.
+To support plotting both mean and standard deviation, there are also the keyword
+arguments `line_kwargs` and `band_kwargs` for adjusting how the mean and
+standard deviation are plotted. Any keyword arguments accepted by
+[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines) can be
+passed to `line_kwargs` and any keyword arguments accepted by
+[`Makie.Band`](https://docs.makie.org/dev/reference/plots/band) can be passed to
+`band_kwargs`.
+
+```@example makie_plots
+using CairoMakie # load a Makie backend
+import EnsembleKalmanProcesses.Visualize as viz
+
+fig_ϕ = CairoMakie.Figure(size = (300 * 2, 300 * 2))
+EnsembleKalmanProcesses.Visualize.plot_ϕ_over_time(fig_ϕ[1, 1], ekp, prior, 1, axis = (xscale = Makie.pseudolog10,))
+ax1 = CairoMakie.Axis(fig_ϕ[1, 2])
+viz.plot_ϕ_mean_over_iters!(ax1, ekp, prior, 2)
+viz.plot_ϕ_mean_over_iters(
+    fig_ϕ[2, 1],
+    ekp,
+    prior,
+    3,
+    plot_std = true,
+    line_kwargs = (linestyle = :dash,),
+    band_kwargs = (alpha = 0.2,),
+    color = :purple,
+)
+fig_ϕ
+```
+
+#### Plot by distribution name
+
+It is also possible to plot the constrained (possibly-multidimensional)
+parameters by name. Note that a gridposition per parameter-dimension must be
+provided for plotting. See the example below.
+
+```@example makie_plots
+using CairoMakie # load a Makie backend
+import EnsembleKalmanProcesses.Visualize as viz
+
+fig_ϕ = CairoMakie.Figure(size = (300 * 2, 300 * 2))
+viz.plot_ϕ_over_time(
+    [fig_ϕ[1, 1], fig_ϕ[1, 2]],
+    ekp,
+    prior,
+    "two_with_spread_2",
+    linewidth = 1.5)
+viz.plot_ϕ_mean_over_time(
+    [fig_ϕ[2, 1], fig_ϕ[2, 2]],
+    ekp,
+    prior,
+    "two_with_spread_2",
+    linewidth = 3.0)
+fig_ϕ
+```
diff --git a/ext/EnsembleKalmanProcessesMakieExt.jl b/ext/EnsembleKalmanProcessesMakieExt.jl
new file mode 100644
index 000000000..1cdbcc0e7
--- /dev/null
+++ b/ext/EnsembleKalmanProcessesMakieExt.jl
@@ -0,0 +1,704 @@
+module EnsembleKalmanProcessesMakieExt
+
+using Makie
+
+import EnsembleKalmanProcesses.Visualize as Visualize
+import EnsembleKalmanProcesses:
+    EnsembleKalmanProcess,
+    get_N_iterations,
+    get_Δt,
+    get_algorithm_time,
+    get_error_metrics,
+    get_ϕ,
+    get_ϕ_mean,
+    get_process,
+    get_prior_mean,
+    get_prior_cov
+using EnsembleKalmanProcesses.ParameterDistributions
+import Statistics: std
+
+using Random
+
+# Plotting recipe does not support plotting multiple plots, so this is a convenience wrapper that
+# plots the priors with no support for passing in keyword arguments
+"""
+    Visualize.plot_parameter_distribution(fig::Union{Makie.Figure, Makie.GridLayout, Makie.GridPosition, Makie.GridSubposition},
+                                          pd::ParameterDistribution;
+                                          constrained = true,
+                                          n_sample = 1e4,
+                                          rng = Random.GLOBAL_RNG)
+
+Plot the distributions `pd` on `fig`.
+"""
+function Visualize.plot_parameter_distribution(
+    fig::Union{Makie.Figure, Makie.GridLayout, Makie.GridPosition, Makie.GridSubposition},
+    pd::ParameterDistribution;
+    constrained = true,
+    n_sample = 1e4,
+    rng = Random.GLOBAL_RNG,
+)
+    samples = sample(rng, pd, Int(n_sample))
+    if constrained
+        samples = transform_unconstrained_to_constrained(pd, samples)
+    end
+    n_plots = ndims(pd)
+    batches = batch(pd)
+
+    # Determine the number of cols for the figure
+    cols = Int(ceil(sqrt(n_plots)))
+
+    for i in 1:n_plots
+        r = div(i - 1, cols) + 1
+        c = (i - 1) % cols + 1
+        batch_id = [j for j = 1:length(batches) if i ∈ batches[j]][1]
+        dim_in_batch = i - minimum(batches[batch_id]) + 1 # i.e. if i=5 in batch 3:6, this would be "3"
+        ax = Makie.Axis(
+            fig[r, c],
+            title = pd.name[batch_id] * " (dim " * string(dim_in_batch) * ")",
+            xgridvisible = false,
+            ygridvisible = false,
+        )
+        Makie.ylims!(ax, 0.0, nothing)
+        Makie.hist!(
+            ax,
+            samples[i, :],
+            normalization = :pdf,
+            color = Makie.Cycled(batch_id),
+            strokecolor = (:black, 0.7),
+            strokewidth = 0.7,
+            bins = Int(floor(sqrt(n_sample))),
+        )
+    end
+    return nothing
+end
+
+"""
+    Visualize.plot_parameter_distribution(fig::Union{Makie.Figure, Makie.GridLayout, Makie.GridPosition, Makie.GridSubposition},
+                                          d::PDT;
+                                          n_sample = 1e4,
+                                          rng = Random.GLOBAL_RNG) where {PDT <: ParameterDistributionType}
+
+Plot the distribution `d` on `fig`.
+"""
+function Visualize.plot_parameter_distribution(
+    fig::Union{Makie.Figure, Makie.GridLayout, Makie.GridPosition, Makie.GridSubposition},
+    d::PDT;
+    n_sample = 1e4,
+    rng = Random.GLOBAL_RNG,
+) where {PDT <: ParameterDistributionType}
+    samples = sample(rng, d, Int(n_sample))
+
+    n_plots = ndims(d)
+
+    # Determine the number of cols for the figure
+    cols = Int(ceil(sqrt(n_plots)))
+
+    for i in 1:n_plots
+        r = div(i - 1, cols) + 1
+        c = (i - 1) % cols + 1
+        batch_id = [j for j = 1:length(batches) if i ∈ batches[j]][1]
+        # i.e.
if i=5 in batch 3:6, this would be "3" + dim_in_batch = i - minimum(batches[batch_id]) + 1 + ax = Makie.Axis( + fig[r, c], + title = pd.name[batch_id] * " (dim " * string(dim_in_batch) * ")", + xgridvisible = false, + ygridvisible = false, + ) + Makie.ylims!(ax, 0.0, nothing) + Makie.hist!( + ax, + samples[i, :], + normalization = :pdf, + color = Makie.Cycled(batch_id), + strokecolor = (:black, 0.7), + strokewidth = 0.7, + bins = Int(floor(sqrt(n_sample))), + ) + end + return nothing +end + +# Define plot recipe for iterations or time vs errors line plots +# Define the function erroranditersortime whose functionality is +# implemented by specializing +# Makie.plot!(plot::ConstrainedParamsAndItersOrTime) +@recipe(ErrorAndItersOrTime, ekp) do scene + Attributes(linewidth = 4.5) +end + +""" + Visualize.plot_error_over_iters(gridposition, ekp; kwargs...) + +Plot the errors from `ekp` against the number of iterations on `gridposition`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines). +""" +Visualize.plot_error_over_iters(args...; kwargs...) = _plot_error_over_iters(erroranditersortime, args...; kwargs...) + +""" + Visualize.plot_error_over_iters!(axis, ekp; kwargs...) + +Plot the errors from `ekp` against the number of iterations on `axis`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines). +""" +Visualize.plot_error_over_iters!(args...; kwargs...) = _plot_error_over_iters(erroranditersortime!, args...; kwargs...) + +""" + _plot_errors_over_iters(plot_fn, args...; kwargs...) + +Helper function to plot the errors against the number of iterations. + +The arguments one can expect is `fig_or_ax` and `ekp` or the last argument. +""" +function _plot_error_over_iters(plot_fn, args...; kwargs...) 
+ ekp = last(args) + iters = 1:get_N_iterations(ekp) + error_metric = _find_appropriate_error(ekp, get(kwargs, :error_metric, nothing)) + plot = plot_fn(args...; xvals = iters, error_metric = error_metric, kwargs...) + + # If an axis is passed in, it must be the first argument. Otherwise, get the + # current axis as determined by Makie. Note that current_axis() does not work + # if an axis is passed in. + ax = typeof(first(args)) <: Makie.AbstractAxis ? first(args) : Makie.current_axis() + _modify_axis!( + ax; + xlabel = "Iterations", + title = "Error over iterations", + ylabel = "$error_metric", + xticks = 1:get_N_iterations(ekp), + ) + return plot +end + +""" + Visualize.plot_error_over_time(gridposition, ekp; kwargs...) + +Plot the errors from `ekp` against time on `gridposition`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines). +""" +Visualize.plot_error_over_time(args...; kwargs...) = _plot_error_over_time(erroranditersortime, args...; kwargs...) + +""" + Visualize.plot_error_over_time!(axis, ekp; kwargs...) + +Plot the errors from `ekp` against time on `axis`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines). +""" +Visualize.plot_error_over_time!(args...; kwargs...) = _plot_error_over_time(erroranditersortime!, args...; kwargs...) + +""" + _plot_error_over_time(plot_fn, fig_or_ax, ekp; kwargs...) + +Helper function to plot errors over time. + +The arguments one can expect is `fig_or_ax` and `ekp` or the last argument. +""" +function _plot_error_over_time(plot_fn, args...; kwargs...) + ekp = last(args) + times = get_algorithm_time(ekp) + error_metric = _find_appropriate_error(ekp, get(kwargs, :error_metric, nothing)) + plot = plot_fn(args...; xvals = times, error_metric = error_metric, kwargs...) 
+ + # If an axis is passed in, it must be the first argument. Otherwise, get the + # current axis as determined by Makie. Note that current_axis() does not work + # if an axis is passed in. + ax = typeof(first(args)) <: Makie.AbstractAxis ? first(args) : Makie.current_axis() + _modify_axis!(ax; xlabel = "Time", title = "Error over time", ylabel = "$error_metric") + return plot +end + +""" + Makie.plot!(plot::ErrorAndItersOrTime) + +Plot the errors against iteration or time. +""" +function Makie.plot!(plot::ErrorAndItersOrTime) + xvals = plot.xvals[] + errors = get_error_metrics(plot.ekp[])[plot.error_metric[]] + valid_attributes = Makie.shared_attributes(plot, Makie.Lines) + Makie.lines!(plot, xvals, errors; valid_attributes...) + return plot +end + +# Define plot recipe for iteration or time vs constrained parameters scatter plots +# Define the function constrainedparamsanditersortime whose functionality is +# implemented by specializing Makie.plot!(plot::ConstrainedParamsAndItersOrTime) +@recipe(ConstrainedParamsAndItersOrTime, ekp, prior, dim_idx) do scene + Theme() +end + +""" + Visualize.plot_ϕ_over_iters(gridposition, ekp, prior, dim_idx; kwargs...) + +Plot the constrained parameter of index `dim_idx` against time on `gridposition`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +(`Makie.Scatter`)[https://docs.makie.org/dev/reference/plots/scatter]. +""" +Visualize.plot_ϕ_over_iters(args...; kwargs...) = + _plot_ϕ_over_iters(constrainedparamsanditersortime, args...; kwargs...) + +""" + Visualize.plot_ϕ_over_iters!(axis, ekp, prior, dim_idx; kwargs...) + +Plot the constrained parameter of index `dim_idx` against time on `axis`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +(`Makie.Scatter`)[https://docs.makie.org/dev/reference/plots/scatter]. +""" +Visualize.plot_ϕ_over_iters!(args...; kwargs...) 
= + _plot_ϕ_over_iters(constrainedparamsanditersortime!, args...; kwargs...) + +""" + _plot_ϕ_over_iters(plot_fn, args...; kwargs...) + +Helper function to plot constrained parameters against the number of iterations. + +The arguments one can expect is `fig_or_ax`, `ekp`, `prior`, and `dim_idx` or +the last three arguments. +""" +function _plot_ϕ_over_iters(plot_fn, args...; kwargs...) + ekp, prior, dim_idx = last(args, 3) + iters = collect(0:get_N_iterations(ekp)) + plot = plot_fn(args...; xvals = iters, kwargs...) + + # If an axis is passed in, it must be the first argument. Otherwise, get the + # current axis as determined by Makie. Note that current_axis() does not work + # if an axis is passed in. + ax = typeof(first(args)) <: Makie.AbstractAxis ? first(args) : Makie.current_axis() + _modify_axis!( + ax; + xlabel = "Iterations", + title = "ϕ over iterations [dim $(_get_dim_of_dist(prior, dim_idx))]", + ylabel = "$(_get_prior_name(prior, dim_idx))", + xticks = 0:get_N_iterations(ekp), + ) + return plot +end + +""" + Visualize.plot_ϕ_over_time(gridposition, ekp, prior, dim_idx; kwargs...) + +Plot the constrained parameter of index `dim_idx` against time on `gridposition`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +(`Makie.Scatter`)[https://docs.makie.org/dev/reference/plots/scatter]. +""" +Visualize.plot_ϕ_over_time(args...; kwargs...) = _plot_ϕ_over_time(constrainedparamsanditersortime, args...; kwargs...) + +""" + Visualize.plot_ϕ_over_time!(axis, ekp, prior, dim_idx; kwargs...) + +Plot the constrained parameter of index `dim_idx` against time on `axis`. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +(`Makie.Scatter`)[https://docs.makie.org/dev/reference/plots/scatter]. +""" +Visualize.plot_ϕ_over_time!(args...; kwargs...) = + _plot_ϕ_over_time(constrainedparamsanditersortime!, args...; kwargs...) 
+ +""" + _plot_ϕ_over_time(plot_fn, args...; kwargs...) + +Helper function to plot constrained parameter over iterations. + +The arguments one can expect is `fig_or_ax`, `ekp`, `prior`, and `dim_idx` or +the last three arguments. +""" +function _plot_ϕ_over_time(plot_fn, args...; kwargs...) + ekp, prior, dim_idx = last(args, 3) + times = get_algorithm_time(ekp) + pushfirst!(times, zero(eltype(times))) + plot = plot_fn(args...; xvals = times, kwargs...) + + # If an axis is passed in, it must be the first argument. Otherwise, get the + # current axis as determined by Makie. Note that current_axis() does not work + # if an axis is passed in. + ax = typeof(first(args)) <: Makie.AbstractAxis ? first(args) : Makie.current_axis() + _modify_axis!( + ax; + xlabel = "Time", + title = "ϕ over time [dim $(_get_dim_of_dist(prior, dim_idx))]", + ylabel = "$(_get_prior_name(prior, dim_idx))", + ) + return plot +end + +""" + Makie.plot!(plot::ConstrainedParamsAndItersOrTime) + +Plot the constrained parameter versus iterations or time. +""" +function Makie.plot!(plot::ConstrainedParamsAndItersOrTime) + ekp, prior, dim_idx, xvals = plot.ekp[], plot.prior[], plot.dim_idx[], plot.xvals[] + ϕ = get_ϕ(prior, ekp) + ensemble_size = size(first(ϕ))[2] + nums = vcat((xval * ones(ensemble_size) for xval in xvals)...) + param_nums = vcat((ϕ[idx][dim_idx, :] for idx in eachindex(xvals))...) + valid_attributes = Makie.shared_attributes(plot, Makie.Scatter) + Makie.scatter!(plot, nums, param_nums; valid_attributes...) 
+ return plot +end + +# Define plot recipe for iteration or time vs constrained parameters scatter +# plots +# Define the function constrainedmeanparamsanditersortime whose functionality is +# implemented by specializing +# Makie.plot!(plot::ConstrainedMeanParamsAndItersOrTime) +@recipe(ConstrainedMeanParamsAndItersOrTime, ekp, prior, dim_idx) do scene + Attributes(linewidth = 4.5) +end + +""" + Visualize.plot_ϕ_mean_over_iters!(axis, ekp, prior, dim_idx; plot_std = false, kwargs...) + +Plot the mean constrained parameter of index `dim_idx` of `prior` against the +number of iterations on `axis`. + +If `plot_std = true`, then the standard deviation of the constrained parameters +of the ensemble is also plotted. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines) and +(`Makie.Band`)[https://docs.makie.org/dev/reference/plots/band] if +`plot_std = true`. + +Keyword arguments passed to `line_kwargs` and `band_kwargs` are merged with +`kwargs` when possible. The keyword arguments in `line_kwargs` and +`band_kwargs` take priority over the keyword arguments in `kwargs`. +""" +Visualize.plot_ϕ_mean_over_iters!(args...; plot_std = false, kwargs...) = + _plot_ϕ_mean_over_iters(constrainedmeanparamsanditersortime!, args...; plot_std, kwargs...) + +""" + Visualize.plot_ϕ_mean_over_iters(gridposition, ekp, prior, dim_idx; plot_std = false, kwargs...) + +Plot the mean constrained parameter of index `dim_idx` of `prior` against the +number of iterations on `gridposition`. + +If `plot_std = true`, then the standard deviation of the constrained parameters +of the ensemble is also plotted. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines) and +(`Makie.Band`)[https://docs.makie.org/dev/reference/plots/band] if +`plot_std = true`. 
+ +Keyword arguments passed to `line_kwargs` and `band_kwargs` are merged with +`kwargs` when possible. The keyword arguments in `line_kwargs` and +`band_kwargs` take priority over the keyword arguments in `kwargs`. +""" +Visualize.plot_ϕ_mean_over_iters(args...; plot_std = false, kwargs...) = + _plot_ϕ_mean_over_iters(constrainedmeanparamsanditersortime, args...; plot_std, kwargs...) + +""" + _plot_ϕ_mean_over_iters(plot_fn, args...; plot_std = false, kwargs...) + +Helper function to plot mean constrained parameter against the number of +iterations. + +The arguments one can expect is `fig_or_ax`, `ekp`, `prior`, and `dim_idx` or +the last three arguments. +""" +function _plot_ϕ_mean_over_iters(plot_fn, args...; plot_std, kwargs...) + ekp, prior, dim_idx = last(args, 3) + iters = collect(0:get_N_iterations(ekp)) + plot = plot_fn(args...; plot_std, xvals = iters, kwargs...) + + # If an axis is passed in, it must be the first argument. Otherwise, get the + # current axis as determined by Makie. Note that current_axis() does not work + # if an axis is passed in. + ax = typeof(first(args)) <: Makie.AbstractAxis ? first(args) : Makie.current_axis() + if plot_std + axis_attribs = ( + xlabel = "Iterations", + title = "Mean and std of ϕ over iterations [dim $(_get_dim_of_dist(prior, dim_idx))]", + ylabel = "$(_get_prior_name(prior, dim_idx))", + xticks = 0:get_N_iterations(ekp), + ) + else + axis_attribs = ( + xlabel = "Iterations", + title = "Mean ϕ over iterations [dim $(_get_dim_of_dist(prior, dim_idx))]", + ylabel = "$(_get_prior_name(prior, dim_idx))", + xticks = 0:get_N_iterations(ekp), + ) + end + _modify_axis!(ax; axis_attribs...) + return plot +end + +""" + Visualize.plot_ϕ_mean_over_time!(axis, ekp, prior, dim_idx; plot_std = false, kwargs...) + +Plot the mean constrained parameter of index `dim_idx` of `prior` against time +on `axis`. + +If `plot_std = true`, then the standard deviation of the constrained parameter +of the ensemble is also plotted. 
+ +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines) and +(`Makie.Band`)[https://docs.makie.org/dev/reference/plots/band] if +`plot_std = true`. + +Keyword arguments passed to `line_kwargs` and `band_kwargs` are merged with +`kwargs` when possible. The keyword arguments in `line_kwargs` and +`band_kwargs` take priority over the keyword arguments in `kwargs`. +""" +Visualize.plot_ϕ_mean_over_time!(args...; plot_std = false, kwargs...) = + _plot_ϕ_mean_over_time(constrainedmeanparamsanditersortime!, args...; plot_std, kwargs...) + +""" + Visualize.plot_ϕ_mean_over_time(gridposition, ekp, prior, dim_idx; plot_std = false, kwargs...) + +Plot the mean constrained parameter of index `dim_idx` of `prior` against time +on `gridposition`. + +If `plot_std = true`, then the standard deviation of the constrained parameter +of the ensemble is also plotted. + +Any keyword arguments is passed to the plotting function which takes in any +keyword arguments supported by +[`Makie.Lines`](https://docs.makie.org/dev/reference/plots/lines) and +(`Makie.Band`)[https://docs.makie.org/dev/reference/plots/band] if +`plot_std = true`. + +Keyword arguments passed to `line_kwargs` and `band_kwargs` are merged with +`kwargs` when possible. The keyword arguments in `line_kwargs` and +`band_kwargs` take priority over the keyword arguments in `kwargs`. +""" +Visualize.plot_ϕ_mean_over_time(args...; plot_std = false, kwargs...) = + _plot_ϕ_mean_over_time(constrainedmeanparamsanditersortime, args...; plot_std, kwargs...) + +""" + _plot_ϕ_mean_over_time(plot_fn, fig_or_ax, ekp, prior, dim_idx; kwargs...) + +Helper function to plot mean constrained parameters against time. + +The arguments one can expect is `fig_or_ax`, `ekp`, `prior`, and `dim_idx` or +the last three arguments. +""" +function _plot_ϕ_mean_over_time(plot_fn, args...; plot_std, kwargs...) 
+ ekp, prior, dim_idx = last(args, 3) + times = accumulate(+, get_Δt(ekp)) + pushfirst!(times, zero(eltype(times))) + plot = plot_fn(args...; plot_std, xvals = times, kwargs...) + + # If an axis is passed in, it must be the first argument. Otherwise, get the + # current axis as determined by Makie. Note that current_axis() does not work + # if an axis is passed in. + ax = typeof(first(args)) <: Makie.AbstractAxis ? first(args) : Makie.current_axis() + if plot_std + axis_attribs = ( + xlabel = "Time", + title = "Mean and std of ϕ over time [dim $(_get_dim_of_dist(prior, dim_idx))]", + ylabel = "$(_get_prior_name(prior, dim_idx))", + ) + else + axis_attribs = ( + xlabel = "Time", + title = "Mean ϕ over time [dim $(_get_dim_of_dist(prior, dim_idx))]", + ylabel = "$(_get_prior_name(prior, dim_idx))", + ) + end + _modify_axis!(ax; axis_attribs...) + return plot +end + +""" + Makie.plot!(plot::ConstrainedMeanParamsAndItersOrTime) + +Plot the mean constrained parameter versus iterations or time. +""" +function Makie.plot!(plot::ConstrainedMeanParamsAndItersOrTime) + ekp, prior, dim_idx, xvals = plot.ekp[], plot.prior[], plot.dim_idx[], plot.xvals[] + ϕ_mean_vals = [get_ϕ_mean(prior, ekp, iter)[dim_idx] for iter in 1:length(xvals)] + + valid_attributes = Makie.shared_attributes(plot, Makie.Lines) + hasproperty(plot, :line_kwargs) && (valid_attributes = merge(plot.line_kwargs, valid_attributes)) + Makie.lines!(plot, xvals, ϕ_mean_vals; valid_attributes...) + + if plot.plot_std[] + stds = [std(get_ϕ(prior, ekp, iter)[dim_idx, :]) for iter in 1:length(xvals)] + + valid_attributes = Makie.shared_attributes(plot, Makie.Band) + hasproperty(plot, :band_kwargs) && (valid_attributes = merge(plot.band_kwargs, valid_attributes)) + Makie.band!(plot, xvals, ϕ_mean_vals .- stds, ϕ_mean_vals .+ stds; valid_attributes...) + end + return plot +end + +""" + Visualize.plot_ϕ_over_iters(gridpositions, ekp, prior, name; kwargs...) 
+ +Plot the constrained parameter belonging to distribution `name` against the +number of iterations on the iterable `gridpositions`. + +Any keyword arguments are passed to all plots. +""" +Visualize.plot_ϕ_over_iters( + gridpositions, + ekp::EnsembleKalmanProcess, + prior::ParameterDistribution, + name::AbstractString; + kwargs..., +) = _plot_ϕ(Visualize.plot_ϕ_over_iters, gridpositions, ekp, prior, name; kwargs...) + +""" + Visualize.plot_ϕ_over_time(gridpositions, ekp, prior, name; kwargs...) + +Plot the constrained parameter belonging to distribution `name` against time on +the iterable `gridpositions`. + +Any keyword arguments are passed to all plots. +""" +Visualize.plot_ϕ_over_time( + gridpositions, + ekp::EnsembleKalmanProcess, + prior::ParameterDistribution, + name::AbstractString; + kwargs..., +) = _plot_ϕ(Visualize.plot_ϕ_over_time, gridpositions, ekp, prior, name; kwargs...) + +""" + Visualize.plot_ϕ_mean_over_iters(gridpositions, ekp::EnsembleKalmanProcess, prior, name; kwargs...) + +Plot the mean constrained parameter belonging to distribution `name` against the +number of iterations on the iterable `gridpositions`. + +Any keyword arguments are passed to all plots. +""" +Visualize.plot_ϕ_mean_over_iters( + gridpositions, + ekp::EnsembleKalmanProcess, + prior::ParameterDistribution, + name::AbstractString; + kwargs..., +) = _plot_ϕ(Visualize.plot_ϕ_mean_over_iters, gridpositions, ekp, prior, name; kwargs...) + +""" + Visualize.plot_ϕ_mean_over_time(gridpositions, ekp, prior, name; kwargs...) + +Plot the mean constrained parameter belonging to distribution `name` against +time on the iterable `gridpositions`. + +Any keyword arguments are passed to all plots. +""" +Visualize.plot_ϕ_mean_over_time( + gridpositions, + ekp::EnsembleKalmanProcess, + prior::ParameterDistribution, + name::AbstractString; + kwargs..., +) = _plot_ϕ(Visualize.plot_ϕ_mean_over_time, gridpositions, ekp, prior, name; kwargs...) 
+ +""" + _plot_ϕ(plot_fn, gridpositions, ekp, prior, name; kwargs...) + +Helper function to plot `ϕ`-related quantities by the name of the distribution. +""" +function _plot_ϕ(plot_fn, gridpositions, ekp, prior, name::AbstractString; kwargs...) + batch_idx = findfirst(x -> x == name, prior.name) + isnothing(batch_idx) && + error("Cannot find a distribution with the name $name; the available distributions are $(prior.name)") + dim_indices = batch(prior)[batch_idx] + length(dim_indices) == length(gridpositions) || error( + "The $name distribution has $(length(dim_indices)) dimensions, but $(length(gridpositions)) gridpositions are passed in", + ) + local plot + for (gridposition, dim_idx) in zip(gridpositions, dim_indices) + plot = plot_fn(gridposition, ekp, prior, dim_idx; kwargs...) + end + return plot +end + +""" + _modify_axis!(ax; kwargs...) + +Modify an axis in place. + +This currently supports only `:title`, `:xlabel`, `:ylabel`, and `:xticks`. + +This function is hacky as Makie does not currently have support for axis hints, +so this is a workaround which will work in most cases. + +For example, it is not possible to tell if the user want an empty title or not +since this function assumes that if the default is used, then it is okay to +overwrite. +""" +function _modify_axis!(ax; kwargs...) + defaults = Dict(:title => "", :xlabel => "", :ylabel => "", :xticks => Makie.Automatic()) + for (attrib, val) in kwargs + if attrib in keys(defaults) && getproperty(ax, attrib)[] == defaults[attrib] + getproperty(ax, attrib)[] = val + end + end + return nothing +end + +""" + _get_dim_of_dist(prior, dim_idx) + +Get the dimension of the distribution belonging to `dim_idx` in `prior`. 
+""" +function _get_dim_of_dist(prior, dim_idx) + for param_arr in batch(prior) + if dim_idx in param_arr + return findfirst(x -> x == dim_idx, param_arr) + end + end + error("Could not find name corresponding to parameter $dim_idx") +end + +""" + _get_prior_name(prior, dim_idx) + +Get name of `dim_idx` parameter from `prior`. +""" +function _get_prior_name(prior, dim_idx) + for (i, param_arr) in enumerate(batch(prior)) + if dim_idx in param_arr + return get_name(prior)[i] + end + end + error("Could not find name corresponding to parameter $dim_idx") +end +""" + _find_appropriate_error(ekp, error_name) + +Find the appropriate name of the error to return if `error_name` is `nothing`. +Otherwise, return `error_name`. +""" +function _find_appropriate_error(ekp, error_name) + if isnothing(error_name) + process = get_process(ekp) + prior_mean = get_prior_mean(process) + prior_cov = get_prior_cov(process) + error_name = (isnothing(prior_mean) || isnothing(prior_cov)) ? "loss" : "bayes_loss" + end + return error_name +end + +end diff --git a/src/EnsembleKalmanProcess.jl b/src/EnsembleKalmanProcess.jl index 1d4f81fa9..bce3644b3 100644 --- a/src/EnsembleKalmanProcess.jl +++ b/src/EnsembleKalmanProcess.jl @@ -11,11 +11,26 @@ using DocStringExtensions export EnsembleKalmanProcess export get_u, get_g, get_ϕ export get_u_prior, get_u_final, get_g_final, get_ϕ_final -export get_N_iterations, get_error, get_cov_blocks +export get_N_iterations, get_error_metrics, get_error, get_cov_blocks +export compute_average_rmse, + compute_loss_at_mean, + compute_average_unweighted_rmse, + compute_unweighted_loss_at_mean, + compute_bayes_loss_at_mean, + compute_crps export get_u_mean, get_u_cov, get_g_mean, get_ϕ_mean export get_u_mean_final, get_u_cov_prior, get_u_cov_final, get_g_mean_final, get_ϕ_mean_final + export get_scheduler, - get_localizer, get_localizer_type, get_accelerator, get_rng, get_Δt, get_failure_handler, get_N_ens, get_process + get_localizer, + get_localizer_type, + 
get_accelerator, + get_rng, + get_Δt, + get_algorithm_time, + get_failure_handler, + get_N_ens, + get_process export get_nan_tolerance, get_nan_row_values export get_observation_series, get_obs, @@ -199,8 +214,8 @@ struct EnsembleKalmanProcess{ N_ens::IT "Array of stores for forward model outputs, each of size [`N_obs × N_ens`]" g::Array{DataContainer{FT}} - "vector of errors" - error::Vector{FT} + "Dict of error metric" + error_metrics::Dict "Scheduler to calculate the timestep size in each EK iteration" scheduler::LRS "accelerator object that informs EK update steps, stores additional state variables as needed" @@ -261,7 +276,7 @@ function EnsembleKalmanProcess( #store for model evaluations g = [] # error store - err = FT[] + err = Dict{String, Vector{FT}}() # timestep store Δt = FT[] @@ -601,12 +616,21 @@ end """ get_Δt(ekp::EnsembleKalmanProcess) -Return `Δt` field of EnsembleKalmanProcess. +Return `Δt` field of EnsembleKalmanProcess. """ function get_Δt(ekp::EnsembleKalmanProcess) return ekp.Δt end +""" +$(TYPEDSIGNATURES) + +Get the accumulated Δt over iterations, also known as algorithm- or pseudo-time. +""" +function get_algorithm_time(ekp::EnsembleKalmanProcess) + return accumulate(+, get_Δt(ekp)) +end + """ get_failure_handler(ekp::EnsembleKalmanProcess) Return `failure_handler` field of EnsembleKalmanProcess. 
@@ -717,23 +741,47 @@ function get_observation_series(ekp::EnsembleKalmanProcess) end """ - get_obs_noise_cov(ekp::EnsembleKalmanProcess; build=true) -convenience function to get the obs_noise_cov from the current batch in ObservationSeries +$(TYPEDSIGNATURES) + +Get the `obs_noise_cov` from the current batch in `ObservationSeries` build=false:, returns a vector of blocks, build=true: returns a block matrix, """ function get_obs_noise_cov(ekp::EnsembleKalmanProcess; build = true) - return get_obs_noise_cov(get_observation_series(ekp), build = build) + return get_obs_noise_cov(get_observation_series(ekp); build = build) +end + +""" +$(TYPEDSIGNATURES) + +Get the `obs_noise_cov` for `iteration` from the `ObservationSeries` +build=false:, returns a vector of blocks, +build=true: returns a block matrix, +""" +function get_obs_noise_cov(ekp::EnsembleKalmanProcess, iteration; build = true) + return get_obs_noise_cov(get_observation_series(ekp), iteration; build = build) end """ - get_obs_noise_cov_inv(ekp::EnsembleKalmanProcess; build=true) -convenience function to get the obs_noise_cov (inverse) from the current batch in ObservationSeries +$(TYPEDSIGNATURES) + +Get the `obs_noise_cov` inverse from the current batch in `ObservationSeries` build=false:, returns a vector of blocks, build=true: returns a block matrix, """ function get_obs_noise_cov_inv(ekp::EnsembleKalmanProcess; build = true) - return get_obs_noise_cov_inv(get_observation_series(ekp), build = build) + return get_obs_noise_cov_inv(get_observation_series(ekp); build = build) +end + +""" +$(TYPEDSIGNATURES) + +Get the obs_noise_cov inverse for `iteration` from the `ObservationSeries` +build=false:, returns a vector of blocks, +build=true: returns a block matrix, +""" +function get_obs_noise_cov_inv(ekp::EnsembleKalmanProcess, iteration; build = true) + return get_obs_noise_cov_inv(get_observation_series(ekp), iteration; build = build) end """ @@ -790,15 +838,28 @@ function lmul_obs_noise_cov_inv!( end """ - 
get_obs(ekp::EnsembleKalmanProcess; build=true) -Get the observation from the current batch in ObservationSeries +$(TYPEDSIGNATURES) + +Get the observation from the current batch in `ObservationSeries` build=false: returns a vector of vectors, build=true: returns a concatenated vector, """ function get_obs(ekp::EnsembleKalmanProcess; build = true) - return get_obs(get_observation_series(ekp), build = build) + return get_obs(get_observation_series(ekp); build = build) end +""" +$(TYPEDSIGNATURES) + +Get the observation for `iteration` from the `ObservationSeries` +build=false: returns a vector of vectors, +build=true: returns a concatenated vector, +""" +function get_obs(ekp::EnsembleKalmanProcess, iteration; build = true) + return get_obs(get_observation_series(ekp), iteration; build = build) +end + + """ update_minibatch!(ekp::EnsembleKalmanProcess) update to the next minibatch in the ObservationSeries @@ -829,27 +890,214 @@ end construct_initial_ensemble(prior::ParameterDistribution, N_ens::IT) where {IT <: Int} = construct_initial_ensemble(Random.GLOBAL_RNG, prior, N_ens) +# metrics of interest """ - compute_error!(ekp::EnsembleKalmanProcess) +$(TYPEDSIGNATURES) -Computes the covariance-weighted error of the mean forward model output, `(ḡ - y)'Γ_inv(ḡ - y)`. -The error is stored within the `EnsembleKalmanProcess`. +Computes the average unweighted error over the forward model ensemble, normalized by the dimension: `1/dim(y) * tr((g_i - y)' * (g_i - y))`. 
+The error is retrievable as `get_error_metrics(ekp)["unweighted_avg_rmse"]`
 """
-function compute_error!(ekp::EnsembleKalmanProcess)
-    mean_g = dropdims(mean(get_g_final(ekp), dims = 2), dims = 2)
+function compute_average_unweighted_rmse(ekp::EnsembleKalmanProcess)
+    g = get_g_final(ekp)
+    succ_ens, _ = split_indices_by_success(g)
+    diff = get_obs(ekp) .- g[:, succ_ens] # column diff
+    dot_diff = 1.0 / size(g, 1) * [dot(d, d) for d in eachcol(diff)] # trace(diff'*diff)
+    if any(dot_diff .< 0)
+        throw(
+            ArgumentError(
+                "Found x'*Γ⁻¹*x < 0. This implies Γ⁻¹ not positive semi-definite, \n please modify Γ (a.k.a noise covariance of the observation) to ensure p.s.d. holds.",
+            ),
+        )
+    end
+
+    ens_rmse = sqrt.(dot_diff) # rmse for each ens member
+    avg_rmse = mean(ens_rmse) # average
+    return avg_rmse
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Computes the average covariance-weighted error over the forward model ensemble, normalized by the dimension: `1/dim(y) * tr((g_i - y)' * Γ⁻¹ * (g_i - y))`.
+The error is retrievable as `get_error_metrics(ekp)["avg_rmse"]`
+"""
+function compute_average_rmse(ekp::EnsembleKalmanProcess)
+    g = get_g_final(ekp)
+    succ_ens, _ = split_indices_by_success(g)
+    diff = get_obs(ekp) .- g[:, succ_ens] # column diff
+    X = lmul_obs_noise_cov_inv(ekp, diff)
+    weight_diff = 1.0 / size(g, 1) * [dot(x, d) for (x, d) in zip(eachcol(diff), eachcol(X))] # trace(diff'*X)
+    if any(weight_diff .< 0)
+        throw(
+            ArgumentError(
+                "Found x'*Γ⁻¹*x < 0. This implies Γ⁻¹ not positive semi-definite, \n please modify Γ (a.k.a noise covariance of the observation) to ensure p.s.d. holds.",
+            ),
+        )
+    end
+    ens_rmse = sqrt.(weight_diff) # rmse for each ens member (no sum: keep per-member values, then average)
+    avg_rmse = mean(ens_rmse) # average
+    return avg_rmse
+end
+
+"""
+$(TYPEDSIGNATURES)
+
+Computes the covariance-weighted error of the mean of the forward model output, normalized by the dimension `1/dim(y) * (ḡ - y)' * Γ⁻¹ * (ḡ - y)`.
+The error is retrievable as `get_error_metrics(ekp)["loss"]` or returned from `get_error(ekp)` if a prior is not provided to the process +""" +function compute_loss_at_mean(ekp::EnsembleKalmanProcess) + g = get_g_final(ekp) + succ_ens, _ = split_indices_by_success(g) + mean_g = dropdims(mean(g[:, succ_ens], dims = 2), dims = 2) diff = get_obs(ekp) - mean_g X = lmul_obs_noise_cov_inv(ekp, diff) - newerr = dot(diff, X) - push!(get_error(ekp), newerr) + newerr = 1.0 / size(g, 1) * dot(diff, X) + return newerr end """ - get_error(ekp::EnsembleKalmanProcess) +$(TYPEDSIGNATURES) + +Computes the unweighted error of the mean of the forward model output, normalized by the dimension `1/dim(y) * (ḡ - y)' * (ḡ - y)`. +The error is retrievable as `get_error_metrics(ekp)["unweighted_loss"]` +""" +function compute_unweighted_loss_at_mean(ekp::EnsembleKalmanProcess) + g = get_g_final(ekp) + succ_ens, _ = split_indices_by_success(g) + mean_g = dropdims(mean(g[:, succ_ens], dims = 2), dims = 2) + diff = get_obs(ekp) - mean_g + newerr = 1.0 / size(g, 1) * dot(diff, diff) + return newerr +end -Returns the mean forward model output error as a function of algorithmic time. """ -get_error(ekp::EnsembleKalmanProcess) = ekp.error +$(TYPEDSIGNATURES) + +Computes the bayes loss of the mean of the forward model output, normalized by dimensions `(1/dim(y)*dim(u)) * [(ḡ - y)' * Γ⁻¹ * (ḡ - y) + (̄u - m)' * C⁻¹ * (̄u - m)]`. +If the prior is not provided to the process on creation of EKP, then `m` and `C` are estimated from the initial ensemble. 
+ +The error is retrievable as `get_error_metrics(ekp)["bayes_loss"]` or returned from `get_error(ekp)` if a prior is provided to the process +""" +function compute_bayes_loss_at_mean(ekp::EnsembleKalmanProcess) + process = get_process(ekp) + prior_mean = get_prior_mean(process) + prior_cov = get_prior_cov(process) + g = get_g_final(ekp) + misfit_at_mean = compute_loss_at_mean(ekp) + + # estimate from initial ensemble if we do not have access to them + # note initial ensemble = prior, for ekp algorithms where prior is not provided + if isnothing(prior_mean) + u_prior = get_u(ekp, 1) + prior_mean = mean(u_prior, dims = 2) + end + if isnothing(prior_cov) + u_prior = get_u(ekp, 1) + prior_cov = cov(u_prior, dims = 2) + if !isposdef(prior_cov) + prior_cov = posdef_correct(prior_cov) + end + end + u = get_u_mean_final(ekp) + udiff = reshape(u - prior_mean, :, 1) + prior_misfit_at_mean = 1.0 / length(u) * dot(udiff, inv(prior_cov) * udiff) + # indep of input and output size + return (1.0 / length(u)) * misfit_at_mean + (1.0 / size(g, 1)) * prior_misfit_at_mean + +end +""" +$(TYPEDSIGNATURES) + +Computes a Gaussian approximation of CRPS (continuous rank probability score) of the ensemble with the observation (performing through a whitening by C^GG, see e.g., Zheng, Sun, 2025, https://arxiv.org/abs/2410.09133). 
+""" +function compute_crps(ekp::EnsembleKalmanProcess) + g = get_g_final(ekp) + succ_ens, _ = split_indices_by_success(g) + g_ens = g[:, succ_ens] + g_mean = mean(g_ens, dims = 2) + diff = get_obs(ekp) - g_mean + + # get svd of the perturbations from samples + g_svd = tsvd_cov_from_samples(g_ens, quiet = true) # Note from this function, .S are evals of cov-g + if size(g_svd.U, 1) == size(g_svd.U, 2) # then work with Vt + white_diff = 1 ./ sqrt.(g_svd.S) .* g_svd.Vt * diff # ~N(0,I) + else + white_diff = 1 ./ sqrt.(g_svd.S) .* g_svd.U' * diff # ~N(0,I) + end + dist = Normal(0, 1) + indep_crps = white_diff .* (2 .* cdf.(dist, white_diff) .- 1) .+ 2 * pdf.(dist, white_diff) .- 1 ./ sqrt(π) + avg_crps = 1 ./ length(g_svd.S) * sum(sqrt.(g_svd.S) .* indep_crps) + return avg_crps +end + + + +""" +$(TYPEDSIGNATURES) + +Computes a variety of error metrics and stores this in `EnsembleKalmanProcess`. (retrievable with `get_error_metrics(ekp)`) +currently available: +- `avg_rmse` computed with `compute_average_rmse(ekp)` +- `loss` computed with `compute_loss_at_mean(ekp)` +- `unweighed_avg_rmse` computed with `compute_average_unweighted_rmse(ekp)` +- `unweighted_loss` computed with `compute_unweighted_loss_at_mean(ekp)` +- `bayes_loss` computed with `compute_bayes_loss_at_mean(ekp)` +- `crps` computed with `compute_crps(ekp)` + +""" +function compute_error!(ekp::EnsembleKalmanProcess) + + rmse = compute_average_rmse(ekp) + loss = compute_loss_at_mean(ekp) + un_rmse = compute_average_unweighted_rmse(ekp) + un_loss = compute_unweighted_loss_at_mean(ekp) + bayes_loss = compute_bayes_loss_at_mean(ekp) + crps = compute_crps(ekp) + + em = get_error_metrics(ekp) + if length(keys(em)) == 0 + em["avg_rmse"] = [rmse] + em["loss"] = [loss] + em["bayes_loss"] = [bayes_loss] + em["unweighted_avg_rmse"] = [un_rmse] + em["unweighted_loss"] = [un_loss] + em["crps"] = [crps] + else + push!(em["avg_rmse"], rmse) + push!(em["loss"], loss) + push!(em["bayes_loss"], bayes_loss) + 
push!(em["unweighted_avg_rmse"], un_rmse) + push!(em["unweighted_loss"], un_loss) + push!(em["crps"], crps) + end +end + +""" + get_error_metrics(ekp::EnsembleKalmanProcess) + +Returns the stored `error_metrics`, created with `compute_error!` +""" +get_error_metrics(ekp::EnsembleKalmanProcess) = ekp.error_metrics + +""" + get_error(ekp::EnsembleKalmanProcess) + +[For back compatability] Returns the relevant loss function that is minimized over EKP iterations. +- If prior provided to the process: `get_error_metrics(ekp)["bayes_loss"]`, the loss computed with `compute_bayes_loss_at_mean(ekp)` +- If prior not provided to the process: `get_error_metrics(ekp)["loss"]`, the loss computed with `compute_loss_at_mean(ekp)` +""" +function get_error(ekp::EnsembleKalmanProcess) + process = get_process(ekp) + prior_mean = get_prior_mean(process) + prior_cov = get_prior_cov(process) + # if prior provided, then return the bayesian loss. Else return the loss + if isnothing(prior_mean) || isnothing(prior_cov) + return get_error_metrics(ekp)["loss"] + else + return get_error_metrics(ekp)["bayes_loss"] + end +end """ sample_empirical_gaussian( diff --git a/src/EnsembleKalmanProcesses.jl b/src/EnsembleKalmanProcesses.jl index 998e7dd4e..57a85a4ef 100644 --- a/src/EnsembleKalmanProcesses.jl +++ b/src/EnsembleKalmanProcesses.jl @@ -14,4 +14,5 @@ include("EnsembleKalmanProcess.jl") # Plot recipes include("PlotRecipes.jl") +include("Visualize.jl") end # module diff --git a/src/EnsembleKalmanSampler.jl b/src/EnsembleKalmanSampler.jl index 35f0a455e..35a205d08 100644 --- a/src/EnsembleKalmanSampler.jl +++ b/src/EnsembleKalmanSampler.jl @@ -27,6 +27,9 @@ function Sampler(prior::ParameterDistribution) return Sampler{FT}(mean_prior, cov_prior) end +get_prior_mean(process::Sampler) = process.prior_mean +get_prior_cov(process::Sampler) = process.prior_cov + function FailureHandler(process::Sampler, method::IgnoreFailures) function failsafe_update(ekp, u, g, failed_ens) diff --git 
a/src/GaussNewtonKalmanInversion.jl b/src/GaussNewtonKalmanInversion.jl index 8f2ee4c2b..7af3b6dc6 100644 --- a/src/GaussNewtonKalmanInversion.jl +++ b/src/GaussNewtonKalmanInversion.jl @@ -20,6 +20,10 @@ function GaussNewtonInversion(prior::ParameterDistribution) return GaussNewtonInversion(mean_prior, cov_prior) end +get_prior_mean(process::GaussNewtonInversion) = process.prior_mean +get_prior_cov(process::GaussNewtonInversion) = process.prior_cov + + # failure handling function FailureHandler(process::GaussNewtonInversion, method::IgnoreFailures) failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) = gnki_update(ekp, u, g, y, obs_noise_cov) diff --git a/src/Observations.jl b/src/Observations.jl index 93bde34b7..2cc24778e 100644 --- a/src/Observations.jl +++ b/src/Observations.jl @@ -25,10 +25,13 @@ export get_samples, get_method, get_rng, get_minibatch_size, + get_length_epoch, get_observations, + get_minibatch_index, get_current_minibatch_index, get_minibatcher, update_minibatch!, + get_minibatch, get_current_minibatch, no_minibatcher, tsvd_mat, @@ -130,7 +133,7 @@ $(TYPEDSIGNATURES) For a given matrix `X` and rank `r`, return the truncated SVD for X as a LinearAlgebra.jl `SVD` object. Setting `return_inverse=true` also return it's psuedoinverse X⁺. """ -function tsvd_mat(X, r::Int; return_inverse = false, tsvd_kwargs...) +function tsvd_mat(X, r::Int; return_inverse = false, quiet = false, tsvd_kwargs...) # Note, must only use tsvd approximation when rank < minimum dimension of X or you get very poor approximation. if isa(X, UniformScaling) if return_inverse @@ -142,7 +145,7 @@ function tsvd_mat(X, r::Int; return_inverse = false, tsvd_kwargs...) rx = rank(X) mindim = minimum(size(X)) if rx <= r - if rx < r + if rx < r && !quiet @warn( "Requested truncation to rank $(r) for an input matrix of rank $(rx). Performing (truncated) SVD for rank $(rx) matrix." ) @@ -169,7 +172,7 @@ function tsvd_mat(X, r::Int; return_inverse = false, tsvd_kwargs...) 
end end -function tsvd_mat(X; return_inverse = false, tsvd_kwargs...) +function tsvd_mat(X; return_inverse = false, quiet = false, tsvd_kwargs...) if isa(X, UniformScaling) throw( ArgumentError( @@ -177,7 +180,7 @@ function tsvd_mat(X; return_inverse = false, tsvd_kwargs...) ), ) end - return tsvd_mat(X, rank(X); return_inverse = return_inverse, tsvd_kwargs...) + return tsvd_mat(X, rank(X); return_inverse = return_inverse, quiet = quiet, tsvd_kwargs...) end """ @@ -190,12 +193,13 @@ function tsvd_cov_from_samples( r::Int; data_are_columns::Bool = true, return_inverse = false, + quiet = false, tsvd_kwargs..., ) where {AM <: AbstractMatrix} mat = data_are_columns ? sample_mat : permutedims(sample_mat, (2, 1)) N = size(mat, 2) - if N > size(mat, 1) + if N > size(mat, 1) && !quiet @warn( "SVD representation is efficient when estimating high-dimensional covariance with few samples. \n here # samples is $(N), while the space dimension is $(size(mat,1)), and representation will be inefficient." 
) @@ -220,6 +224,7 @@ function tsvd_cov_from_samples( sample_mat::AM; data_are_columns::Bool = true, return_inverse = false, + quiet = false, tsvd_kwargs..., ) where {AM <: AbstractMatrix} # (need to compute this to get the rank as debiasing can change it) @@ -231,6 +236,7 @@ function tsvd_cov_from_samples( rk; data_are_columns = data_are_columns, return_inverse = return_inverse, + quiet = quiet, tsvd_kwargs..., ) end @@ -1057,15 +1063,34 @@ end """ $(TYPEDSIGNATURES) +gets the number of minibatches in an epoch +""" +function get_length_epoch(os::OS) where {OS <: ObservationSeries} + return length(get_minibatches(os)[1]) +end + +""" +$(TYPEDSIGNATURES) + +returns the minibatch_index `Dict("epoch"=> x, "minibatch" => y)`, for a given `iteration` +""" +function get_minibatch_index(os::OS, iteration::Int) where {OS <: ObservationSeries} + len_epoch = get_length_epoch(os) + return Dict("epoch" => ((iteration - 1) ÷ len_epoch) + 1, "minibatch" => ((iteration - 1) % len_epoch) + 1) +end + +""" +$(TYPEDSIGNATURES) + Within an epoch: iterates the current minibatch index by one. At the end of an epoch: obtains a new epoch of minibatches from the `Minibatcher` updates the epoch index by one, and minibatch index to one. 
""" function update_minibatch!(os::OS) where {OS <: ObservationSeries} index = get_current_minibatch_index(os) - minibatches_in_epoch = get_minibatches(os)[index["epoch"]] + len_epoch = get_length_epoch(os) # take new batch - if length(minibatches_in_epoch) >= index["minibatch"] + 1 + if len_epoch >= index["minibatch"] + 1 index["minibatch"] += 1 #update by 1 else index["minibatch"] = 1 # set to 1 @@ -1075,7 +1100,23 @@ function update_minibatch!(os::OS) where {OS <: ObservationSeries} new_epoch = create_new_epoch!(minibatcher, collect(1:length(get_observations(os)))) # create a new sweep of minibatches push!(minibatches, new_epoch) end +end + +""" +$(TYPEDSIGNATURES) +get the minibatch for a given minibatch index (`Dict("epoch"=> x, "minibatch" => y)`), or iteration `Int`. If `nothing` is provided as an iteration then the current minibatch is returned +""" +function get_minibatch(os::OS, it_or_mbi::IorDorN) where {OS <: ObservationSeries, IorDorN <: Union{Int, Dict, Nothing}} + if isnothing(it_or_mbi) + return get_current_minibatch(os) + else + index = isa(it_or_mbi, Dict) ? it_or_mbi : get_minibatch_index(os, it_or_mbi) + minibatches = get_minibatches(os) + epoch = index["epoch"] + mini = index["minibatch"] + return minibatches[epoch][mini] + end end # stored as vector of vectors @@ -1093,10 +1134,10 @@ end """ $(TYPEDSIGNATURES) -if `build=true` then gets the observed sample, stacked over the current minibatch. `build=false` lists the `samples` for all observations +if `build=true` then gets the observed sample, stacked over the minibatch at `iteration`. `build=false` lists the `samples` for all observations. If `isnothing(iteration)` or not defined then the current iteration is used. 
""" -function get_obs(os::OS; build = true) where {OS <: ObservationSeries} - minibatch = get_current_minibatch(os) # gives the indices of the minibatch +function get_obs(os::OS, iteration::IorN; build = true) where {OS <: ObservationSeries, IorN <: Union{Int, Nothing}} + minibatch = get_minibatch(os, iteration) minibatch_length = length(minibatch) observations_vec = get_observations(os)[minibatch] # gives observation objects if minibatch_length == 1 @@ -1120,10 +1161,22 @@ end """ $(TYPEDSIGNATURES) -if `build=true` then gets the observation covariance matrix, blocked over the current minibatch. `build=false` lists the `covs` for all observations +if `build=true` then gets the observed sample, stacked over the current minibatch. `build=false` lists the `samples` for all observations. +""" +get_obs(os::OS; kwargs...) where {OS <: ObservationSeries} = get_obs(os, nothing; kwargs...) + + +""" +$(TYPEDSIGNATURES) + +if `build=true` then gets the observation covariance matrix, blocked over the minibatch at `iteration`. `build=false` lists the `covs` for all observations. If `isnothing(iteration)` or not defined then the current iteration is used. """ -function get_obs_noise_cov(os::OS; build = true) where {OS <: ObservationSeries} - minibatch = get_current_minibatch(os) # gives the indices of the minibatch +function get_obs_noise_cov( + os::OS, + iteration::IorN; + build = true, +) where {OS <: ObservationSeries, IorN <: Union{Int, Nothing}} + minibatch = get_minibatch(os, iteration) minibatch_length = length(minibatch) observations_vec = get_observations(os)[minibatch] # gives observation objects if minibatch_length == 1 # if only 1 sample then return it @@ -1154,10 +1207,22 @@ end """ $(TYPEDSIGNATURES) -if `build=true` then gets the inverse of the observation covariance matrix, blocked over the current minibatch. `build=false` lists the `inv_covs` for all observations +if `build=true` then gets the observation covariance matrix, blocked over the current minibatch. 
`build=false` lists the `covs` for all observations """ -function get_obs_noise_cov_inv(os::OS; build = true) where {OS <: ObservationSeries} - minibatch = get_current_minibatch(os) # gives the indices of the minibatch +get_obs_noise_cov(os::OS; kwargs...) where {OS <: ObservationSeries} = get_obs_noise_cov(os, nothing; kwargs...) + + +""" +$(TYPEDSIGNATURES) + +if `build=true` then gets the inverse of the observation covariance matrix, blocked over minibatch at `iteration`. `build=false` lists the `inv_covs` for all observations. If `isnothing(iteration)` or not defined then the current iteration is used. +""" +function get_obs_noise_cov_inv( + os::OS, + iteration::IorN; + build = true, +) where {OS <: ObservationSeries, IorN <: Union{Int, Nothing}} + minibatch = get_minibatch(os, iteration) minibatch_length = length(minibatch) observations_vec = get_observations(os)[minibatch] # gives observation objects if minibatch_length == 1 # if only 1 sample then return it @@ -1185,6 +1250,13 @@ function get_obs_noise_cov_inv(os::OS; build = true) where {OS <: ObservationSer end +""" +$(TYPEDSIGNATURES) + +if `build=true` then gets the inverse of the observation covariance matrix, blocked over the current minibatch. `build=false` lists the `inv_covs` for all observations. +""" +get_obs_noise_cov_inv(os::OS; kwargs...) where {OS <: ObservationSeries} = get_obs_noise_cov_inv(os, nothing; kwargs...) 
+ get_cov_size(am::AM) where {AM <: AbstractMatrix} = size(am, 1) function get_cov_size(am::SVD) diff --git a/src/SparseEnsembleKalmanInversion.jl b/src/SparseEnsembleKalmanInversion.jl index b8500b980..acabf4d7e 100644 --- a/src/SparseEnsembleKalmanInversion.jl +++ b/src/SparseEnsembleKalmanInversion.jl @@ -37,6 +37,10 @@ function SparseInversion( return SparseInversion{FT}(γ, threshold_value, uc_idx, reg) end +get_prior_mean(process::SparseInversion) = nothing +get_prior_cov(process::SparseInversion) = nothing + + function FailureHandler(process::SparseInversion, method::IgnoreFailures) failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens) = sparse_eki_update(ekp, u, g, y, obs_noise_cov) return FailureHandler{SparseInversion, IgnoreFailures}(failsafe_update) diff --git a/src/UnscentedKalmanInversion.jl b/src/UnscentedKalmanInversion.jl index 7bc9e7e06..5bc412102 100644 --- a/src/UnscentedKalmanInversion.jl +++ b/src/UnscentedKalmanInversion.jl @@ -146,8 +146,6 @@ function Unscented( cov_weights[1] = λ / (N_par + λ) + 1 - α^2 + 2.0 cov_weights[2:N_ens] .= 1 / (2 * (N_par + λ)) - - elseif sigma_points == "simplex" c_weights = zeros(FT, N_par, N_ens) @@ -288,6 +286,9 @@ function EnsembleKalmanProcess( return EnsembleKalmanProcess(observation, process; kwargs...) 
end +get_prior_mean(process::UorTU) where {UorTU <: Union{Unscented, TransformUnscented}} = process.prior_mean +get_prior_cov(process::UorTU) where {UorTU <: Union{Unscented, TransformUnscented}} = process.prior_cov + function FailureHandler(process::Unscented, method::IgnoreFailures) function failsafe_update(uki, u, g, u_idx, g_idx, failed_ens) #perform analysis on the model runs @@ -618,7 +619,7 @@ function construct_perturbation( if isa(x, AbstractMatrix{FT}) @assert isa(x_mean, AbstractVector{FT}) xx_pert = zeros(size(x)) - for i in 1:size(xx_pert, 2) + for i in 2:size(xx_pert, 2) # first column always zero (as it is the mean) xx_pert[:, i] = sqrt(cov_weights[i]) * (x[:, i] .- x_mean) end else @@ -626,7 +627,7 @@ function construct_perturbation( N_ens = length(x) xx_pert = zeros(N_ens) - for i in 1:N_ens + for i in 2:N_ens # first entry always zero (as it is the mean) xx_pert[i] += sqrt(cov_weights[i]) * (x[i] .- x_mean) end end @@ -836,16 +837,78 @@ function get_u_cov( return get_process(uki).uu_cov[iteration] end -function compute_error!( + +function compute_loss_at_mean( uki::EnsembleKalmanProcess{FT, IT, UorTU}, ) where {FT <: AbstractFloat, IT <: Int, UorTU <: Union{Unscented, TransformUnscented}} mean_g = get_process(uki).obs_pred[end] diff = get_obs(uki) - mean_g X = lmul_obs_noise_cov_inv(uki, diff) - newerr = dot(diff, X) - push!(get_error(uki), newerr) + newerr = 1.0 / length(mean_g) * dot(diff, X) + return newerr +end + +function compute_unweighted_loss_at_mean( + uki::EnsembleKalmanProcess{FT, IT, UorTU}, +) where {FT <: AbstractFloat, IT <: Int, UorTU <: Union{Unscented, TransformUnscented}} + mean_g = get_process(uki).obs_pred[end] + diff = get_obs(uki) - mean_g + newerr = 1.0 / length(mean_g) * dot(diff, diff) + return newerr end +function compute_crps( + uki::EnsembleKalmanProcess{FT, IT, UorTU}, +) where {FT <: AbstractFloat, IT <: Int, UorTU <: Union{Unscented, TransformUnscented}} + g = get_g_final(uki) + successful_ens, _ = 
split_indices_by_success(g) + g_mean = construct_successful_mean(uki, g, successful_ens) + diff = get_obs(uki) - g_mean + if length(g_mean) > length(successful_ens) # dim > ens + # get svd directly from the perturbations (not samples) + g_perturb = construct_successful_perturbation(uki, g, g_mean, successful_ens)[:, 2:end] # first column zeros + g_svd = svd(g_perturb) # Note this svd gives, .S are sqrt-evals of cov-g + white_diff = 1 ./ g_svd.S .* g_svd.U' * diff # as g_perturb was tall + + dist = Normal(0, 1) + indep_crps = white_diff .* (2 .* cdf.(dist, white_diff) .- 1) .+ 2 * pdf.(dist, white_diff) .- 1 ./ sqrt(π) + avg_crps = 1 ./ length(g_svd.S) * sum(g_svd.S .* indep_crps) + return avg_crps + + else # dim < ens + g_cov = construct_successful_cov(uki, g, g_mean, successful_ens) + g_svd = svd(g_cov) # Note this svd gives, .S are evals + + white_diff = 1 ./ sqrt.(g_svd.S) .* g_svd.Vt * diff # ~N(0,I) + + dist = Normal(0, 1) + indep_crps = white_diff .* (2 .* cdf.(dist, white_diff) .- 1) .+ 2 * pdf.(dist, white_diff) .- 1 ./ sqrt(π) + avg_crps = 1 ./ length(g_svd.S) * sum(sqrt.(g_svd.S) .* indep_crps) + return avg_crps + end + +end + +""" +For Unscented processes it doesn't make sense to average RMSE at sigma points, so it is evaluated at the mean only, where it is exactly equal to `sqrt(compute_loss_at_mean(uki))` +""" +function compute_average_rmse( + uki::EnsembleKalmanProcess{FT, IT, UorTU}, +) where {FT <: AbstractFloat, IT <: Int, UorTU <: Union{Unscented, TransformUnscented}} + return sqrt(compute_loss_at_mean(uki)) +end + +""" +For Unscented processes it doesn't make sense to average unweighted RMSE at sigma points, so it is evaluated at the mean only, where it is exactly equal to `sqrt(compute_unweighted_loss_at_mean(uki))` +""" +function compute_average_unweighted_rmse( + uki::EnsembleKalmanProcess{FT, IT, UorTU}, +) where {FT <: AbstractFloat, IT <: Int, UorTU <: Union{Unscented, TransformUnscented}} + return sqrt(compute_unweighted_loss_at_mean(uki)) 
+end + + + function Gaussian_2d( u_mean::AbstractVector{FT}, diff --git a/src/Visualize.jl b/src/Visualize.jl new file mode 100644 index 000000000..e91593aad --- /dev/null +++ b/src/Visualize.jl @@ -0,0 +1,29 @@ +module Visualize + +function plot_parameter_distribution end + +function plot_error_over_iters end + +function plot_error_over_iters! end + +function plot_error_over_time end + +function plot_error_over_time! end + +function plot_ϕ_over_iters end + +function plot_ϕ_over_iters! end + +function plot_ϕ_over_time end + +function plot_ϕ_over_time! end + +function plot_ϕ_mean_over_iters! end + +function plot_ϕ_mean_over_iters end + +function plot_ϕ_mean_over_time! end + +function plot_ϕ_mean_over_time end + +end diff --git a/test/EnsembleKalmanProcess/runtests.jl b/test/EnsembleKalmanProcess/runtests.jl index 9415e5d0f..735977079 100644 --- a/test/EnsembleKalmanProcess/runtests.jl +++ b/test/EnsembleKalmanProcess/runtests.jl @@ -626,6 +626,9 @@ end ## some getters in EKP @test get_obs(ekiobj) == y_obs @test get_obs_noise_cov(ekiobj) == Γy + @test get_obs(ekiobj, 1) == y_obs + @test get_obs_noise_cov(ekiobj, 1) == Γy + @test get_obs_noise_cov_inv(ekiobj) == get_obs_noise_cov_inv(ekiobj, 1) g_ens = G(get_ϕ_final(prior, ekiobj)) g_ens_t = permutedims(g_ens, (2, 1)) @@ -722,11 +725,20 @@ end end push!(u_i_vec, get_u_final(ekiobj)) + algtime = get_algorithm_time(ekiobj) + @test algtime == accumulate(+, get_Δt(ekiobj)) + @test get_u_prior(ekiobj) == u_i_vec[1] @test get_u(ekiobj) == u_i_vec @test isequal(get_g(ekiobj), g_ens_vec) @test isequal(get_g_final(ekiobj), g_ens_vec[end]) - @test isequal(get_error(ekiobj), ekiobj.error) + @test isequal(get_error_metrics(ekiobj), ekiobj.error_metrics) + + # get_error should give the appropriate loss + @test isequal(get_error(ekiobj), get_error_metrics(ekiobj)["loss"]) + @test isequal(get_error(ekiobj_inf), get_error_metrics(ekiobj_inf)["bayes_loss"]) + + # EKI results: Test if ensemble has collapsed toward the true parameter # 
values @@ -829,7 +841,16 @@ end for (i_prob, inv_problem, impose_prior, update_freq) in zip(1:length(inv_problems), inv_problems, impose_priors, update_freqs) - y_obs, G, Γy, A = inv_problem + if i_prob == 1 # we do one of the problems with n_obs < n_ens (for code coverage) + y_obs, Gold, Γy, A = inv_problem + # input_dim = 2 -> n_ens = 4 or 7 + y_obs = y_obs[1:3] + Γy = Γy[1:3, 1:3] + A = A[1:3, :] + G(x) = Gold(x)[1:3, :] + else + y_obs, G, Γy, A = inv_problem + end scheduler = DataMisfitController(on_terminate = "continue") #will need to be copied as stores run information inside scheduler_simplex = DefaultScheduler(0.05) #will need to be copied as stores run information inside @@ -964,7 +985,7 @@ end @test get_u(ekpobj) == u_i_vec @test isequal(get_g(ekpobj), g_ens_vec) @test isequal(get_g_final(ekpobj), g_ens_vec[end]) - @test isequal(get_error(ekpobj), ekpobj.error) + @test isequal(get_error_metrics(ekpobj), ekpobj.error_metrics) @test isa(construct_mean(ekpobj, rand(rng, 2 * n_par + 1)), Float64) @test isa(construct_mean(ekpobj, rand(rng, 5, 2 * n_par + 1)), Vector{Float64}) @@ -1139,7 +1160,12 @@ end @test get_u(ekiobj) == u_i_vec @test isequal(get_g(ekiobj), g_ens_vec) @test isequal(get_g_final(ekiobj), g_ens_vec[end]) - @test isequal(get_error(ekiobj), ekiobj.error) + @test isequal(get_error_metrics(ekiobj), ekiobj.error_metrics) + + # get_error should give the appropriate loss + @test isequal(get_error(ekiobj), get_error_metrics(ekiobj)["loss"]) + @test isequal(get_error(ekiobj_inf), get_error_metrics(ekiobj_inf)["bayes_loss"]) + # ETKI results: Test if ensemble has collapsed toward the true parameter # values @@ -1490,7 +1516,7 @@ end @test get_u(gnkiobj) == u_i_vec @test isequal(get_g(gnkiobj), g_ens_vec) # can deal with NaNs @test isequal(get_g_final(gnkiobj), g_ens_vec[end]) - @test isequal(get_error(gnkiobj), gnkiobj.error) + @test isequal(get_error_metrics(gnkiobj), gnkiobj.error_metrics) # GNKI results: Test if ensemble has collapsed toward the 
true parameter # values diff --git a/test/Observations/runtests.jl b/test/Observations/runtests.jl index 30037d23f..f5dd2215f 100644 --- a/test/Observations/runtests.jl +++ b/test/Observations/runtests.jl @@ -130,7 +130,6 @@ using EnsembleKalmanProcesses end @test onci_full == full - end @testset "covariance utilities" begin @@ -362,6 +361,8 @@ end rng = Random.MersenneTwister(11023) given_batches = [collect(((i - 1) * batch_size + 1):(i * batch_size)) for i in 1:Int(floor(maximum(sample_ids) / batch_size))] + n_batches_per_epoch = maximum(sample_ids) / batch_size + minibatcher = FixedMinibatcher(given_batches, "random", copy(rng)) observation_series = ObservationSeries(obs_vec, minibatcher, series_names, metadata = metadata) minibatcher = FixedMinibatcher(given_batches, "random", copy(rng)) @@ -377,6 +378,10 @@ end @test get_observations(observation_series) == obs_vec @test get_minibatches(observation_series) == [new_epoch] @test get_current_minibatch_index(observation_series) == Dict("epoch" => 1, "minibatch" => 1) + @test get_length_epoch(observation_series) == n_batches_per_epoch + @test get_minibatch_index(observation_series, 1) == Dict("epoch" => 1, "minibatch" => 1) + @test get_minibatch_index(observation_series, 8) == + Dict("epoch" => ((8 - 1) ÷ n_batches_per_epoch) + 1, "minibatch" => ((8 - 1) % n_batches_per_epoch) + 1) @test get_minibatcher(observation_series) == minibatcher @test get_names(observation_series) == series_names @test get_metadata(observation_series) == metadata @@ -398,14 +403,17 @@ end # test the minibatch updating with epochs @test get_current_minibatch(observation_series) == new_epoch[1] + @test get_minibatch(observation_series, 1) == new_epoch[1] for i in 2:5 # we have 5 batches update_minibatch!(observation_series) @test get_current_minibatch(observation_series) == new_epoch[i] + @test get_minibatch(observation_series, i) == new_epoch[i] end update_minibatch!(observation_series) new_epoch2 = create_new_epoch!(minibatcher, 
given_batches) @test get_current_minibatch_index(observation_series) == Dict("epoch" => 2, "minibatch" => 1) @test get_current_minibatch(observation_series) == new_epoch2[1] + @test get_minibatch(observation_series, 6) == new_epoch2[1] @test get_minibatches(observation_series) == [new_epoch, new_epoch2] # test the no minibatch option @@ -424,10 +432,13 @@ end # get_obs (def: build = true) mb = new_epoch2[1] + mb_idx = 6 obs_minibatch = get_obs(observation_series, build = false) @test get_obs.(obs_vec[mb], build = false) == obs_minibatch obs_minibatch_stack = get_obs(observation_series) @test reduce(vcat, get_obs.(obs_vec[mb])) == obs_minibatch_stack + @test get_obs(observation_series, mb_idx, build = false) == obs_minibatch + @test get_obs(observation_series) == obs_minibatch_stack # get_obs_noise_cov obs_noise_cov_minibatch_blocks = get_obs_noise_cov(observation_series, build = false) @@ -437,6 +448,7 @@ end end minibatch_covs = reduce(vcat, minibatch_covs) @test minibatch_covs == obs_noise_cov_minibatch_blocks + @test minibatch_covs == get_obs_noise_cov(observation_series, mb_idx, build = false) obs_noise_cov_minibatch_full = get_obs_noise_cov(observation_series) minibatch_covs = [] @@ -452,6 +464,11 @@ end idx_min[1] += block_sizes[i] end @test minibatch_cov_full == obs_noise_cov_minibatch_full + @test minibatch_cov_full == get_obs_noise_cov(observation_series, mb_idx) + + # just test the indexing a last time + @test get_obs_noise_cov_inv(observation_series) == get_obs_noise_cov_inv(observation_series, mb_idx) + end @@ -496,7 +513,7 @@ end Γfullinv[13:21, 13:21] = svd_cov.U * Diagonal(1.0 ./ svd_cov.S) * svd_cov.U' Γfullinv[22:33, 22:33] = inv(Γfull[22:33, 22:33]) #small enough to compute directly Γinv = get_obs_noise_cov_inv(observation_series) - @test norm(Γfullinv - Γinv) < 1e-12 + @test norm(Γfullinv - Γinv) < 1e-11 # Test utilities for multiplying covariances without building them # test lmul_obs_noise_cov @@ -504,13 +521,13 @@ end X = randn(size(Γ, 1), 50) 
test1 = Γ * Xvec test2 = Γ * X - @test norm(test1 - lmul_obs_noise_cov(observation_series, Xvec)) < 1e-12 - @test norm(test2 - lmul_obs_noise_cov(observation_series, X)) < 1e-12 + @test norm(test1 - lmul_obs_noise_cov(observation_series, Xvec)) < 1e-11 + @test norm(test2 - lmul_obs_noise_cov(observation_series, X)) < 1e-11 test1 = Γinv * Xvec test2 = Γinv * X - @test norm(test1 - lmul_obs_noise_cov_inv(observation_series, Xvec)) < 1e-12 - @test norm(test2 - lmul_obs_noise_cov_inv(observation_series, X)) < 1e-12 + @test norm(test1 - lmul_obs_noise_cov_inv(observation_series, Xvec)) < 1e-11 + @test norm(test2 - lmul_obs_noise_cov_inv(observation_series, X)) < 1e-11 # test the pre-indexed versions: # cov @@ -522,12 +539,12 @@ end # The function is writte to be applied to pre-trimmed "X"s lmul_obs_noise_cov!(out1, observation_series, Xvec[g_idx], g_idx) lmul_obs_noise_cov!(out2, observation_series, X[g_idx, :], g_idx) - @test norm(test1 - out1) < 1e-12 - @test norm(test2 - out2) < 1e-12 + @test norm(test1 - out1) < 1e-11 + @test norm(test2 - out2) < 1e-11 lmul_obs_noise_cov!(out1, observation_series, Xvec, g_idx) lmul_obs_noise_cov!(out2, observation_series, X, g_idx) - @test norm(test1 - out1) < 1e-12 - @test norm(test2 - out2) < 1e-12 + @test norm(test1 - out1) < 1e-11 + @test norm(test2 - out2) < 1e-11 # cov_inv @@ -538,11 +555,11 @@ end # In code this is applied to pre-trimmed "X"s or not lmul_obs_noise_cov_inv!(out1, observation_series, Xvec[g_idx], g_idx) lmul_obs_noise_cov_inv!(out2, observation_series, X[g_idx, :], g_idx) - @test norm(test1 - out1) < 1e-12 - @test norm(test2 - out2) < 1e-12 + @test norm(test1 - out1) < 1e-11 + @test norm(test2 - out2) < 1e-11 lmul_obs_noise_cov_inv!(out1, observation_series, Xvec, g_idx) lmul_obs_noise_cov_inv!(out2, observation_series, X, g_idx) - @test norm(test1 - out1) < 1e-12 - @test norm(test2 - out2) < 1e-12 + @test norm(test1 - out1) < 1e-11 + @test norm(test2 - out2) < 1e-11 end diff --git 
a/test/SparseInversion/runtests.jl b/test/SparseInversion/runtests.jl index e55e73b81..d11af0d58 100644 --- a/test/SparseInversion/runtests.jl +++ b/test/SparseInversion/runtests.jl @@ -124,7 +124,8 @@ include("../EnsembleKalmanProcess/inverse_problem.jl") @test get_u(ekiobj) == u_i_vec @test isequal(get_g(ekiobj), g_ens_vec) @test isequal(get_g_final(ekiobj), g_ens_vec[end]) - @test isequal(get_error(ekiobj), ekiobj.error) + @test isequal(get_error_metrics(ekiobj), ekiobj.error_metrics) + @test isequal(get_error(ekiobj), get_error_metrics(ekiobj)["loss"]) # EKI results: Test if ensemble has collapsed toward the true constrained parameter # values diff --git a/test/Visualize/runtests.jl b/test/Visualize/runtests.jl new file mode 100644 index 000000000..18f940ec1 --- /dev/null +++ b/test/Visualize/runtests.jl @@ -0,0 +1,225 @@ +if TEST_PLOT_OUTPUT + using Test + using CairoMakie + + using LinearAlgebra + using EnsembleKalmanProcesses + using EnsembleKalmanProcesses.ParameterDistributions + + @testset "Helper functions" begin + prior_u1 = constrained_gaussian("positive_with_mean_1", 2, 1, 0, Inf) + prior_u2 = constrained_gaussian("four_with_spread_3", 0, 5, -Inf, Inf, repeats = 3) + prior_u3 = constrained_gaussian("positive_with_mean_1", 0, 3, -Inf, Inf) + prior = combine_distributions([prior_u1, prior_u2, prior_u3]) + + # Since these functions are not exported, we need to access the + # functions like this + ext = Base.get_extension(EnsembleKalmanProcesses, :EnsembleKalmanProcessesMakieExt) + @test ext._get_dim_of_dist(prior, 1) == 1 + @test ext._get_dim_of_dist(prior, 2) == 1 + @test ext._get_dim_of_dist(prior, 3) == 2 + @test ext._get_dim_of_dist(prior, 4) == 3 + @test ext._get_dim_of_dist(prior, 5) == 1 + @test_throws ErrorException ext._get_dim_of_dist(prior, 6) + + @test ext._get_prior_name(prior, 1) == "positive_with_mean_1" + @test ext._get_prior_name(prior, 2) == "four_with_spread_3" + @test ext._get_prior_name(prior, 3) == "four_with_spread_3" + @test 
ext._get_prior_name(prior, 4) == "four_with_spread_3" + @test ext._get_prior_name(prior, 5) == "positive_with_mean_1" + @test_throws ErrorException ext._get_prior_name(prior, 6) + + end + @testset "Makie ploting" begin + # Access plots at tmp_dir + tmp_dir = mktempdir(cleanup = false) + @info "Tempdir", tmp_dir + + # Fix seed, so the plots do not differ from run to run + import Random + rng_seed = 1234 + rng = Random.MersenneTwister(rng_seed) + + G(u) = [1.0 / abs(u[1]), sum(u[2:3]), u[3], u[1]^2 - u[2] - u[3], u[1], 5.0] .+ 0.1 * randn(6) + true_u = [3, 1, 2] + y = G(true_u) + Γ = (0.1)^2 * I + + prior_u1 = constrained_gaussian("positive_with_mean_1", 2, 1, 0, Inf) + prior_u2 = constrained_gaussian("two_with_spread_2", 0, 5, -Inf, Inf, repeats = 2) + prior = combine_distributions([prior_u1, prior_u2]) + + N_ensemble = 6 + initial_ensemble = construct_initial_ensemble(prior, N_ensemble) + ekp = EnsembleKalmanProcess(initial_ensemble, y, Γ, Inversion(), verbose = true) + + N_iterations = 8 + for i in 1:N_iterations + params_i = get_ϕ_final(prior, ekp) + + G_matrix = hcat( + [G(params_i[:, i]) for i in 1:N_ensemble]..., # Parallelize here! 
+ ) + + update_ensemble!(ekp, G_matrix) + end + + # Test functions for plotting distributions and errors over time or iterations + fig1 = CairoMakie.Figure(size = (200 * 4, 200 * 7)) + layout = fig1[1, 1:2] = CairoMakie.GridLayout() + EnsembleKalmanProcesses.Visualize.plot_parameter_distribution(layout, prior) + EnsembleKalmanProcesses.Visualize.plot_error_over_iters( + fig1[2, 1], + ekp, + color = :tomato, + axis = (xlabel = "Iterations [added by axis keyword argument]",), + ) + EnsembleKalmanProcesses.Visualize.plot_error_over_time( + fig1[2, 2], + ekp, + linestyle = :dash, + auto_log_scale = true, + error_metric = "bayes_loss", + ) + ax1 = CairoMakie.Axis(fig1[3, 1], title = "Error over iterations (called from mutating function)") + ax2 = CairoMakie.Axis(fig1[3, 2], title = "Error over time (called from mutating function)") + EnsembleKalmanProcesses.Visualize.plot_error_over_iters!(ax1, ekp, color = :aquamarine, auto_log_scale = true) + EnsembleKalmanProcesses.Visualize.plot_error_over_time!( + ax2, + ekp, + linestyle = :dashdotdot, + auto_log_scale = true, + error_metric = "bayes_loss", + ) + save(joinpath(tmp_dir, "priors_and_errors.png"), fig1) + + # Test functions for plotting ϕ over time or iterations + fig2 = CairoMakie.Figure(size = (400 * 2, 400 * 2)) + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_iters(fig2[1, 1], ekp, prior, 1) + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_time(fig2[1, 2], ekp, prior, 2) + ax1 = CairoMakie.Axis(fig2[2, 1], title = "Constrained parameters (called from mutating function)") + ax2 = CairoMakie.Axis(fig2[2, 2], title = "Constrained parameters (called from mutating function)") + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_iters!(ax1, ekp, prior, 1) + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_time!(ax2, ekp, prior, 2) + + save(joinpath(tmp_dir, "phi.png"), fig2) + + # Test functions for plotting mean ϕ over time or iterations + fig3 = CairoMakie.Figure(size = (400 * 2, 400 * 2)) + 
EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_iters(fig3[1, 1], ekp, prior, 1) + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time(fig3[1, 2], ekp, prior, 2) + ax1 = CairoMakie.Axis(fig3[2, 1], xlabel = "Iters (called from mutating function)") + ax2 = CairoMakie.Axis(fig3[2, 2], xlabel = "Time (called from mutating function)") + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_iters!(ax1, ekp, prior, 1) + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time!(ax2, ekp, prior, 2) + save(joinpath(tmp_dir, "mean_phi.png"), fig3) + + # Test functions for plotting mean and std of ϕ over time or iterations + fig4 = CairoMakie.Figure(size = (400 * 2, 400 * 2)) + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_iters( + fig4[1, 1], + ekp, + prior, + 1, + plot_std = true, + band_kwargs = (alpha = 0.2,), + color = :red, + ) + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time( + fig4[1, 2], + ekp, + prior, + 2, + plot_std = true, + line_kwargs = (linestyle = :dash,), + band_kwargs = (alpha = 0.2,), + color = :purple, + ) + ax1 = CairoMakie.Axis(fig4[2, 1], xlabel = "Iters (called from mutating function)") + ax2 = CairoMakie.Axis(fig4[2, 2], xlabel = "Time (called from mutating function)") + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_iters!( + ax1, + ekp, + prior, + 1, + plot_std = true, + band_kwargs = (color = (:blue, 0.2),), + color = :red, + ) + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time!( + ax2, + ekp, + prior, + 2, + plot_std = true, + line_kwargs = (linestyle = :dash,), + band_kwargs = (color = (:green, 0.2),), + color = :purple, + ) + + save(joinpath(tmp_dir, "mean_and_std_phi.png"), fig4) + + # Test plotting functions for plotting ϕ by name + fig5 = CairoMakie.Figure(size = (400 * 2, 400 * 4)) + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_iters((fig5[1, 1], fig5[1, 2]), ekp, prior, "two_with_spread_2") + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_time((fig5[2, 1], fig5[2, 2]), ekp, prior, "two_with_spread_2") + 
EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_iters( + (fig5[3, 1], fig5[3, 2]), + ekp, + prior, + "two_with_spread_2", + linewidth = 1.5, + ) + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time( + (fig5[4, 1], fig5[4, 2]), + ekp, + prior, + "two_with_spread_2", + linewidth = 3, + ) + save(joinpath(tmp_dir, "plot_by_name.png"), fig5) + + # Error handling + @test_throws ErrorException EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time( + (fig5[5, 1], fig5[5, 2]), + ekp, + prior, + "name not present", + ) + @test_throws ErrorException EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time( + (fig5[6, 1],), + ekp, + prior, + "two_with_spread_2", + ) + + # Test different plotting signatures + # We do not test all possible combinations because there are too many to + # test, so we only test plot_fn(args...; kwargs...) + mkdir(joinpath(tmp_dir, "diff_plot_signatures")) + error_fns = [ + EnsembleKalmanProcesses.Visualize.plot_error_over_iters, + EnsembleKalmanProcesses.Visualize.plot_error_over_time, + ] + for error_fn in error_fns + fig_diff_signs, _, _ = error_fn(ekp) + save(joinpath(tmp_dir, "diff_plot_signatures", "$error_fn.png"), fig_diff_signs) + end + + ϕ_fns = [ + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_iters, + EnsembleKalmanProcesses.Visualize.plot_ϕ_over_time, + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_iters, + EnsembleKalmanProcesses.Visualize.plot_ϕ_mean_over_time, + ] + for ϕ_fn in ϕ_fns + fig_diff_signs, _, _ = ϕ_fn(ekp, prior, 1) + save(joinpath(tmp_dir, "diff_plot_signatures", "$ϕ_fn.png"), fig_diff_signs) + end + + # Access plots at tmp_dir + # Print info again since it is a little tedious to find it if you miss + # it the first time + @info "Tempdir", tmp_dir + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 90665380f..fe002104d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ end "TOMLInterface", "SparseInversion", "Inflation", + "Visualize", ] if all_tests || 
has_submodule(submodule) || "EnsembleKalmanProcesses" in ARGS include_test(submodule)