Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Ontology: steepest gradient descent. Allow for small variations in th…
…e learning rate
  • Loading branch information
re-cursion committed Sep 21, 2016
commit 71aa5c217555f692bb9fbfce7c059ea9e1e48b80
14 changes: 7 additions & 7 deletions tmva/tmva/inc/TMVA/NeuralNet.icc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ template <bool HasDropOut, typename ItSource, typename ItWeight, typename ItPrev
// plotWeights (localWeights);

double alpha = gaussDouble (m_alpha, m_alpha/2.0);
// double alpha = m_alpha;
// double alpha = m_alpha;

auto itG = begin (gradients);
auto itGEnd = end (gradients);
Expand Down Expand Up @@ -781,7 +781,7 @@ template <typename LAYERDATA>
size_t patternPerThread = testPattern.size () / numThreads;
std::vector<Batch> batches;
auto itPat = testPattern.begin ();
// auto itPatEnd = testPattern.end ();
// auto itPatEnd = testPattern.end ();
for (size_t idxThread = 0; idxThread < numThreads-1; ++idxThread)
{
batches.push_back (Batch (itPat, itPat + patternPerThread));
Expand All @@ -798,24 +798,24 @@ template <typename LAYERDATA>
std::async (std::launch::async, [&]()
{
std::vector<double> localOutput;
pass_through_type passThrough (settings, batch, dropContainerTest);
pass_through_type passThrough (settings, batch, dropContainerTest);
double testBatchError = (*this) (passThrough, weights, ModeOutput::FETCH, localOutput);
return std::make_tuple (testBatchError, localOutput);
})
);
}

auto itBatch = batches.begin ();
auto itBatch = batches.begin ();
for (auto& f : futures)
{
std::tuple<double,std::vector<double>> result = f.get ();
testError += std::get<0>(result) / batches.size ();
std::vector<double> output = std::get<1>(result);

//if (output.size () == testPattern.size ())
//if (output.size () == testPattern.size ())
{
//auto it = begin (testPattern);
auto it = (*itBatch).begin ();
//auto it = begin (testPattern);
auto it = (*itBatch).begin ();
for (double out : output)
{
settings.testSample (0, out, (*it).output ().at (0), (*it).weight ());
Expand Down