Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
TMVA MultiProcessing
  • Loading branch information
mammadhajili authored and mammadhajili committed Aug 18, 2017
commit b191f1bac12f481a6a1af5c82a40ce79ff159445
5 changes: 5 additions & 0 deletions tmva/tmva/inc/LinkDef4.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,9 @@
#pragma link C++ function TMVA::CreateVariableTransform;
#pragma link C++ function TMVA::DataLoaderCopy;

#pragma link C++ function TMVA::DataLoaderCopy;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also export the DataLoaderMP copy here?


#pragma link C++ class std::map<TString, Double_t>+;


#endif
5 changes: 5 additions & 0 deletions tmva/tmva/inc/TMVA/Config.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ namespace TMVA {

Bool_t DrawProgressBar() const { return fDrawProgressBar; }
void SetDrawProgressBar( Bool_t d ) { fDrawProgressBar = d; }

UInt_t NWorkers() const {return fNWorkers; } // number of workers used in multiprocessing parallelization
void SetNWorkers (UInt_t n) {fNWorkers = n; } // set the number of multiprocessing workers (1 = sequential)

public:

Expand Down Expand Up @@ -118,11 +121,13 @@ namespace TMVA {
std::atomic<Bool_t> fSilent; // no output at all
std::atomic<Bool_t> fWriteOptionsReference; // if set true: Configurable objects write file with option reference
std::atomic<Bool_t> fDrawProgressBar; // draw progress bar to indicate training evolution
std::atomic<UInt_t> fNWorkers; // number of workers in multiprocessing parallelization
#else
Bool_t fUseColoredConsole; // coloured standard output
Bool_t fSilent; // no output at all
Bool_t fWriteOptionsReference; // if set true: Configurable objects write file with option reference
Bool_t fDrawProgressBar; // draw progress bar to indicate training evolution
UInt_t fNWorkers;
#endif
mutable MsgLogger* fLogger; // message logger
MsgLogger& Log() const { return *fLogger; }
Expand Down
6 changes: 5 additions & 1 deletion tmva/tmva/inc/TMVA/DataLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@
#include <vector>
#include <map>
#include "TCut.h"
#include <memory>

#include "TMVA/Factory.h"
#include "TMVA/Types.h"
#include "TMVA/DataSet.h"
#include "TFile.h"

class TFile;
//class TFile;
class TTree;
class TDirectory;
class TH2;
Expand Down Expand Up @@ -226,6 +228,8 @@ namespace TMVA {
ClassDef(DataLoader,3);
};
void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
std::vector < std::shared_ptr<TFile> > DataLoaderCopyMP(TMVA::DataLoader *des, TMVA::DataLoader *src);
void DataLoaderCopyMPCloseFiles(std::vector< std::shared_ptr<TFile> > files);
} // namespace TMVA

#endif
Expand Down
2 changes: 1 addition & 1 deletion tmva/tmva/src/Config.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ TMVA::Config::Config() :
fSilent ( kFALSE ),
fWriteOptionsReference( kFALSE ),
fDrawProgressBar ( kFALSE ),
fNWorkers ( 1 ),
fLogger ( new MsgLogger("Config") )
{
// plotting
Expand Down Expand Up @@ -113,4 +114,3 @@ TMVA::Config& TMVA::Config::Instance()
return fgConfigPtr ? *fgConfigPtr :*(fgConfigPtr = new Config());
#endif
}

60 changes: 46 additions & 14 deletions tmva/tmva/src/CrossValidation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "TMVA/ResultsClassification.h"
#include "TMVA/tmvaglob.h"
#include "TMVA/Types.h"
#include "ROOT/TProcessExecutor.hxx"


#include "TSystem.h"
#include "TAxis.h"
Expand All @@ -20,6 +22,10 @@

#include <iostream>
#include <memory>
using namespace std;

//const UInt_t nWorkers = 2U;


/*! \class TMVA::CrossValidationResult
\ingroup TMVA
Expand Down Expand Up @@ -126,16 +132,17 @@ void TMVA::CrossValidation::Evaluate()
fFoldStatus=kTRUE;
}

// Process K folds
for(UInt_t i=0; i<fNumFolds; ++i){
Log() << kDEBUG << "Fold (" << methodTitle << "): " << i << Endl;

auto workItem = [&](UInt_t workerID) {

Log() << kDEBUG << "Fold (" << methodTitle << "): " << workerID << Endl;
// Get specific fold of dataset and setup method
TString foldTitle = methodTitle;
foldTitle += "_fold";
foldTitle += i+1;

fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTesting);
MethodBase* smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
foldTitle += workerID+1;
auto classifier = std::unique_ptr<Factory>(new TMVA::Factory("CrossValidation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"));
fDataLoader->PrepareFoldDataSet(workerID, TMVA::Types::kTesting);
MethodBase* smethod = classifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);

// Train method
Event::SetIsTraining(kTRUE);
Expand All @@ -147,10 +154,11 @@ void TMVA::CrossValidation::Evaluate()
smethod->TestClassification();

// Store results
fResults.fROCs[i] = fClassifier->GetROCIntegral(fDataLoader->GetName(),methodTitle);
auto res = classifier->GetROCIntegral(fDataLoader->GetName(),methodTitle);
//fResults.fROCs[workerID] = classifier->GetROCIntegral(fDataLoader->GetName(),methodTitle);

TGraph* gr = fClassifier->GetROCCurve(fDataLoader->GetName(), methodTitle, true);
gr->SetLineColor(i+1);
TGraph* gr = classifier->GetROCCurve(fDataLoader->GetName(), methodTitle, true);
gr->SetLineColor(workerID+1);
gr->SetLineWidth(2);
gr->SetTitle(foldTitle.Data());
fResults.fROCCurves->Add(gr);
Expand All @@ -159,7 +167,7 @@ void TMVA::CrossValidation::Evaluate()
fResults.fSeps.push_back(smethod->GetSeparation());

Double_t err;
fResults.fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01",Types::kTesting, err));
fResults.fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01",Types::kTesting,err));
fResults.fEff10s.push_back(smethod->GetEfficiency("Efficiency:0.10",Types::kTesting,err));
fResults.fEff30s.push_back(smethod->GetEfficiency("Efficiency:0.30",Types::kTesting,err));
fResults.fEffAreas.push_back(smethod->GetEfficiency("" ,Types::kTesting,err));
Expand All @@ -170,9 +178,33 @@ void TMVA::CrossValidation::Evaluate()
// Clean-up for this fold
smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
fClassifier->DeleteAllMethods();
fClassifier->fMethodsMap.clear();
}
classifier->DeleteAllMethods();
classifier->fMethodsMap.clear();

return make_pair(res, workerID);
};

vector< pair < double, UInt_t > > res;

auto nWorkers = TMVA::gConfig().NWorkers();

if(nWorkers > 1) {
ROOT::TProcessExecutor workers(nWorkers);
res = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
}

else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the structure

if () {
   ...
} else {
   ...
}


for(UInt_t i = 0; i < fNumFolds; ++ i) {
auto res_pair = workItem(i);
res.push_back(res_pair);
}

}

for(auto res_pair: res) {
fResults.fROCs[res_pair.second] = res_pair.first;
}

TMVA::gConfig().SetSilent(kFALSE);
Log() << kINFO << "Evaluation done." << Endl;
Expand Down
42 changes: 39 additions & 3 deletions tmva/tmva/src/DataLoader.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ void TMVA::DataLoader::MakeKFoldDataSet(UInt_t numberFolds, bool validationSet){

void TMVA::DataLoader::PrepareFoldDataSet(UInt_t foldNumber, Types::ETreeType tt){


UInt_t numFolds = fTrainSigEvents.size();

std::vector<Event*>* tempTrain = new std::vector<Event*>;
Expand Down Expand Up @@ -853,17 +854,52 @@ TMVA::DataLoader* TMVA::DataLoader::MakeCopy(TString name)

void TMVA::DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src)
{
   // Register every input tree of `src` with `des`, preserving weight and tree type.
   // BUG FIX: each Add*Tree call was duplicated, so every tree was registered
   // twice in the destination loader — add each tree exactly once.
   for (std::vector<TreeInfo>::const_iterator treeinfo = src->DataInput().Sbegin();
        treeinfo != src->DataInput().Send(); ++treeinfo) {
      des->AddSignalTree(treeinfo->GetTree(), treeinfo->GetWeight(), treeinfo->GetTreeType());
   }

   for (std::vector<TreeInfo>::const_iterator treeinfo = src->DataInput().Bbegin();
        treeinfo != src->DataInput().Bend(); ++treeinfo) {
      des->AddBackgroundTree(treeinfo->GetTree(), treeinfo->GetWeight(), treeinfo->GetTreeType());
   }
}

std::vector< std::shared_ptr<TFile> > TMVA::DataLoaderCopyMP(TMVA::DataLoader *des, TMVA::DataLoader *src) {

   // Multiprocessing-safe copy: re-open the files backing src's signal/background
   // input trees and register freshly read trees with `des`. The opened TFile
   // handles are returned so the caller can keep them alive while `des` is in use
   // and release them later via DataLoaderCopyMPCloseFiles().
   //
   // BUG FIX: the loop condition used the comma operator
   //   (signal != Send(), back != Bend())
   // which discards the first comparison entirely — with unequal list lengths the
   // signal iterator could be advanced past Send(). Use '&&' so iteration stops
   // at the end of the shorter list.
   // NOTE(review): assumes the signal and background lists are parallel (same
   // length); extra entries in the longer list are not copied — confirm callers.
   std::vector< std::shared_ptr<TFile> > vec_files;
   for (std::vector<TreeInfo>::const_iterator treeinfo_signal = src->DataInput().Sbegin(),
                                              treeinfo_back   = src->DataInput().Bbegin();
        treeinfo_signal != src->DataInput().Send() && treeinfo_back != src->DataInput().Bend();
        ++treeinfo_signal, ++treeinfo_back)
   {
      TTree *stree = treeinfo_signal->GetTree();
      TTree *btree = treeinfo_back->GetTree();

      const TString sfileName = stree->GetCurrentFile()->GetName();
      const TString bfileName = btree->GetCurrentFile()->GetName();
      std::shared_ptr<TFile> sfile(TFile::Open(sfileName));
      std::shared_ptr<TFile> bfile;
      if (bfileName != sfileName) {
         bfile = std::shared_ptr<TFile>(TFile::Open(bfileName));
      } else {
         // Signal and background live in the same file: share one handle.
         bfile = sfile;
      }
      TTree *signalTree = static_cast<TTree*>(sfile->Get(stree->GetName()));
      TTree *backgTree  = static_cast<TTree*>(bfile->Get(btree->GetName()));
      des->AddSignalTree(signalTree);
      des->AddBackgroundTree(backgTree);

      vec_files.push_back(sfile);
      vec_files.push_back(bfile);
   }
   return vec_files;
}
void TMVA::DataLoaderCopyMPCloseFiles(std::vector<std::shared_ptr<TFile> > files) {
   // Close every TFile handle previously returned by DataLoaderCopyMP().
   for (const auto &openedFile : files) {
      openedFile->Close();
   }
}
////////////////////////////////////////////////////////////////////////////////
/// returns the correlation matrix of datasets

Expand Down
44 changes: 31 additions & 13 deletions tmva/tmva/src/HyperParameterOptimisation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
#include "TMultiGraph.h"
#include "TString.h"
#include "TSystem.h"
#include "ROOT/TProcessExecutor.hxx"

#include <iostream>
#include <memory>
#include <vector>

using namespace std;
/*! \class TMVA::HyperParameterOptimisationResult
\ingroup TMVA

Expand All @@ -29,6 +31,8 @@

*/

//const int nWorkers = 4U;

TMVA::HyperParameterOptimisationResult::HyperParameterOptimisationResult()
: fROCAVG(0.0), fROCCurves(std::make_shared<TMultiGraph>())
{
Expand Down Expand Up @@ -98,27 +102,41 @@ void TMVA::HyperParameterOptimisation::Evaluate()
fFoldStatus=kTRUE;
}
fResults.fMethodName = methodName;
auto workItem = [&](UInt_t workerID) {
TString foldTitle = methodTitle;
foldTitle += "_opt";
foldTitle += workerID+1;

for(UInt_t i = 0; i < fNumFolds; ++i) {

TString foldTitle = methodTitle;
foldTitle += "_opt";
foldTitle += i+1;
Event::SetIsTraining(kTRUE);
fDataLoader->PrepareFoldDataSet(workerID, TMVA::Types::kTraining);

Event::SetIsTraining(kTRUE);
fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTraining);
auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);

auto smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
auto params=smethod->OptimizeTuningParameters(fFomType,fFitType);

auto params=smethod->OptimizeTuningParameters(fFomType,fFitType);
fResults.fFoldParameters.push_back(params);
//fResults.fFoldParameters.push_back(params);

smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);

fClassifier->DeleteAllMethods();
fClassifier->DeleteAllMethods();

fClassifier->fMethodsMap.clear();
fClassifier->fMethodsMap.clear();

return params;
};
vector < map<TString,Double_t> > res;
auto nWorkers = TMVA::gConfig().NWorkers();
if(nWorkers> 1) {
ROOT::TProcessExecutor workers(nWorkers);
res = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
}
else {
for(UInt_t i = 0; i < fNumFolds; ++ i) {
res.push_back(workItem(i));
}
}
for(auto results: res) {
fResults.fFoldParameters.push_back(results);
}

}
Loading