File tree Expand file tree Collapse file tree 4 files changed +10
-19
lines changed
Expand file tree Collapse file tree 4 files changed +10
-19
lines changed Original file line number Diff line number Diff line change @@ -147,20 +147,11 @@ namespace blackdrops {
147147 }
148148 }
149149
150- std::tuple<Eigen::VectorXd, double > predict (const Eigen::VectorXd& x) const
150+ std::tuple<Eigen::VectorXd, Eigen::VectorXd > predict (const Eigen::VectorXd& x, bool compute_variance = true ) const
151151 {
152- Eigen::VectorXd ms;
153- Eigen::VectorXd ss;
154- std::tie (ms, ss) = predictm (x);
155- return std::make_tuple (ms, ss.mean ());
156- }
157-
158- std::tuple<Eigen::VectorXd, Eigen::VectorXd> predictm (const Eigen::VectorXd& x) const
159- {
160- Eigen::VectorXd ms (_gp_model.dim_out ());
161- Eigen::VectorXd ss (_gp_model.dim_out ());
162-
163- return _gp_model.query (x);
152+ if (compute_variance)
153+ return _gp_model.query (x);
154+ return std::make_tuple (_gp_model.mu (x), Eigen::VectorXd::Zero (_gp_model.dim_out ()));
164155 }
165156
166157 Eigen::MatrixXd samples () const
Original file line number Diff line number Diff line change @@ -112,7 +112,7 @@ namespace blackdrops {
112112 }
113113 }
114114
115- std::tuple<Eigen::VectorXd, Eigen::VectorXd> predictm (const Eigen::VectorXd& x) const
115+ std::tuple<Eigen::VectorXd, Eigen::VectorXd> predict (const Eigen::VectorXd& x, bool ) const
116116 {
117117 Eigen::VectorXd mu = _mean (x, x);
118118 Eigen::VectorXd ss = Eigen::VectorXd::Zero (mu.size ());
@@ -142,6 +142,6 @@ namespace blackdrops {
142142 return limbo::opt::no_grad (-mse);
143143 }
144144 };
145- }
145+ } // namespace blackdrops
146146
147147#endif
Original file line number Diff line number Diff line change @@ -189,7 +189,7 @@ namespace blackdrops {
189189
190190 Eigen::VectorXd mu;
191191 Eigen::VectorXd sigma;
192- std::tie (mu, sigma) = model.predictm (query_vec);
192+ std::tie (mu, sigma) = model.predict (query_vec);
193193
194194 Eigen::VectorXd final = init_diff + mu;
195195
@@ -229,7 +229,7 @@ namespace blackdrops {
229229
230230 Eigen::VectorXd mu;
231231 Eigen::VectorXd sigma;
232- std::tie (mu, sigma) = model.predictm (query_vec);
232+ std::tie (mu, sigma) = model.predict (query_vec);
233233
234234 if (Params::blackdrops::stochastic ()) {
235235 sigma = sigma.array ().sqrt ();
Original file line number Diff line number Diff line change @@ -163,7 +163,7 @@ namespace blackdrops {
163163
164164 Eigen::VectorXd mu;
165165 Eigen::VectorXd sigma;
166- std::tie (mu, sigma) = model.predictm (query_vec);
166+ std::tie (mu, sigma) = model.predict (query_vec);
167167
168168 Eigen::VectorXd final = init_diff + mu;
169169
@@ -203,7 +203,7 @@ namespace blackdrops {
203203
204204 Eigen::VectorXd mu;
205205 Eigen::VectorXd sigma;
206- std::tie (mu, sigma) = model.predictm (query_vec);
206+ std::tie (mu, sigma) = model.predict (query_vec, Params::blackdrops::stochastic () );
207207
208208 if (Params::blackdrops::stochastic ()) {
209209 sigma = sigma.array ().sqrt ();
You can’t perform that action at this time.
0 commit comments