Skip to content

Commit 837cbb1

Browse files
committed
Minor cleaning in models
1 parent 4f3979b commit 837cbb1

File tree

4 files changed

+10
-19
lines changed

4 files changed

+10
-19
lines changed

include/blackdrops/gp_model.hpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,11 @@ namespace blackdrops {
147147
}
148148
}
149149

150-
std::tuple<Eigen::VectorXd, double> predict(const Eigen::VectorXd& x) const
150+
std::tuple<Eigen::VectorXd, Eigen::VectorXd> predict(const Eigen::VectorXd& x, bool compute_variance = true) const
151151
{
152-
Eigen::VectorXd ms;
153-
Eigen::VectorXd ss;
154-
std::tie(ms, ss) = predictm(x);
155-
return std::make_tuple(ms, ss.mean());
156-
}
157-
158-
std::tuple<Eigen::VectorXd, Eigen::VectorXd> predictm(const Eigen::VectorXd& x) const
159-
{
160-
Eigen::VectorXd ms(_gp_model.dim_out());
161-
Eigen::VectorXd ss(_gp_model.dim_out());
162-
163-
return _gp_model.query(x);
152+
if (compute_variance)
153+
return _gp_model.query(x);
154+
return std::make_tuple(_gp_model.mu(x), Eigen::VectorXd::Zero(_gp_model.dim_out()));
164155
}
165156

166157
Eigen::MatrixXd samples() const

include/blackdrops/mi_model.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ namespace blackdrops {
112112
}
113113
}
114114

115-
std::tuple<Eigen::VectorXd, Eigen::VectorXd> predictm(const Eigen::VectorXd& x) const
115+
std::tuple<Eigen::VectorXd, Eigen::VectorXd> predict(const Eigen::VectorXd& x, bool) const
116116
{
117117
Eigen::VectorXd mu = _mean(x, x);
118118
Eigen::VectorXd ss = Eigen::VectorXd::Zero(mu.size());
@@ -142,6 +142,6 @@ namespace blackdrops {
142142
return limbo::opt::no_grad(-mse);
143143
}
144144
};
145-
}
145+
} // namespace blackdrops
146146

147147
#endif

include/blackdrops/system/dart_system.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ namespace blackdrops {
189189

190190
Eigen::VectorXd mu;
191191
Eigen::VectorXd sigma;
192-
std::tie(mu, sigma) = model.predictm(query_vec);
192+
std::tie(mu, sigma) = model.predict(query_vec);
193193

194194
Eigen::VectorXd final = init_diff + mu;
195195

@@ -229,7 +229,7 @@ namespace blackdrops {
229229

230230
Eigen::VectorXd mu;
231231
Eigen::VectorXd sigma;
232-
std::tie(mu, sigma) = model.predictm(query_vec);
232+
std::tie(mu, sigma) = model.predict(query_vec);
233233

234234
if (Params::blackdrops::stochastic()) {
235235
sigma = sigma.array().sqrt();

include/blackdrops/system/ode_system.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ namespace blackdrops {
163163

164164
Eigen::VectorXd mu;
165165
Eigen::VectorXd sigma;
166-
std::tie(mu, sigma) = model.predictm(query_vec);
166+
std::tie(mu, sigma) = model.predict(query_vec);
167167

168168
Eigen::VectorXd final = init_diff + mu;
169169

@@ -203,7 +203,7 @@ namespace blackdrops {
203203

204204
Eigen::VectorXd mu;
205205
Eigen::VectorXd sigma;
206-
std::tie(mu, sigma) = model.predictm(query_vec);
206+
std::tie(mu, sigma) = model.predict(query_vec, Params::blackdrops::stochastic());
207207

208208
if (Params::blackdrops::stochastic()) {
209209
sigma = sigma.array().sqrt();

0 commit comments

Comments
 (0)