Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tmva/tmva/inc/LinkDef4.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,9 @@
#pragma link C++ function TMVA::CreateVariableTransform;
#pragma link C++ function TMVA::DataLoaderCopy;

#pragma link C++ function TMVA::DataLoaderCopy;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also export the DataLoaderMP copy here?


#pragma link C++ class std::map<TString, Double_t>+;


#endif
16 changes: 10 additions & 6 deletions tmva/tmva/inc/TMVA/Config.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// @(#)root/tmva $Id$
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss

/**********************************************************************************
Expand Down Expand Up @@ -47,7 +47,7 @@ namespace TMVA {
class MsgLogger;

class Config {

public:

static Config& Instance();
Expand All @@ -65,6 +65,9 @@ namespace TMVA {
Bool_t DrawProgressBar() const { return fDrawProgressBar; }
void SetDrawProgressBar( Bool_t d ) { fDrawProgressBar = d; }

UInt_t NWorkers() const { return fNWorkers; }
void SetNWorkers (UInt_t n) { fNWorkers = n; }

public:

class VariablePlotting;
Expand Down Expand Up @@ -97,8 +100,7 @@ namespace TMVA {
TString fWeightFileExtension;
TString fOptionsReferenceFileDir;
} fIONames; // Customisable weight file properties



private:

// private constructor
Expand All @@ -110,23 +112,25 @@ namespace TMVA {
static std::atomic<Config*> fgConfigPtr;
#else
static Config* fgConfigPtr;
#endif
#endif
private:

#if __cplusplus > 199711L
std::atomic<Bool_t> fUseColoredConsole; // coloured standard output
std::atomic<Bool_t> fSilent; // no output at all
std::atomic<Bool_t> fWriteOptionsReference; // if set true: Configurable objects write file with option reference
std::atomic<Bool_t> fDrawProgressBar; // draw progress bar to indicate training evolution
std::atomic<UInt_t> fNWorkers;
#else
Bool_t fUseColoredConsole; // coloured standard output
Bool_t fSilent; // no output at all
Bool_t fWriteOptionsReference; // if set true: Configurable objects write file with option reference
Bool_t fDrawProgressBar; // draw progress bar to indicate training evolution
UInt_t fNWorkers;
#endif
mutable MsgLogger* fLogger; // message logger
MsgLogger& Log() const { return *fLogger; }

ClassDef(Config,0); // Singleton class for global configuration settings
};

Expand Down
41 changes: 21 additions & 20 deletions tmva/tmva/inc/TMVA/DataLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
#include <vector>
#include <map>
#include "TCut.h"
#include <memory>

#include "TMVA/Factory.h"
#include "TMVA/Types.h"
#include "TMVA/DataSet.h"
#include "TFile.h"

class TFile;
class TTree;
class TDirectory;
class TH2;
Expand Down Expand Up @@ -84,16 +85,15 @@ namespace TMVA {
// special case: signal/background

// Data input related
void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
void SetInputTrees(const TString& signalFileName, const TString& backgroundFileName,
Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 );
void SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut );
void SetInputTrees( TTree *inputTree, const TCut& SigCut, const TCut& BgCut );
// Set input trees at once
void SetInputTrees( TTree* signal, TTree* background,
Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;
void SetInputTrees( TTree *signal, TTree* background, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;

void AddSignalTree( TTree* signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TTree *signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );
void AddSignalTree( TTree *signal, Double_t weight, const TString& treetype );

// ... deprecated, kept for backwards compatibility
void SetSignalTree( TTree* signal, Double_t weight=1.0);
Expand All @@ -109,9 +109,9 @@ namespace TMVA {
void SetBackgroundWeightExpression( const TString& variable );

// special case: regression
void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
Types::ETreeType treetype = Types::kMaxTreeType ) {
AddTree( tree, "Regression", weight, "", treetype );
void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
Types::ETreeType treetype = Types::kMaxTreeType ) {
AddTree( tree, "Regression", weight, "", treetype );
}

// general
Expand Down Expand Up @@ -153,10 +153,10 @@ namespace TMVA {
void PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt );
void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt );

// ... deprecated, kept for backwards compatibility
// ... deprecated, kept for backwards compatibility
void PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 );

void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
const TString& otherOpt="SplitMode=Random:!V" );

void PrepareTrainingAndTestTree( int foldNumber, Types::ETreeType tt );
Expand All @@ -168,15 +168,15 @@ namespace TMVA {
const DataSetInfo& GetDefaultDataSetInfo(){ return DefaultDataSetInfo(); }

TH2* GetCorrelationMatrix(const TString& className);

// Copy method used in VI and CV. DEPRECATED: you can just call Clone instead: DataLoader *dl2=(DataLoader *)dl1->Clone("dl2")
DataLoader* MakeCopy(TString name);
friend void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
friend void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
DataInputHandler& DataInput() { return *fDataInputHandler; }

private:


DataSetInfo& DefaultDataSetInfo();
void SetInputTreesFromEventAssignTrees();

Expand All @@ -188,7 +188,7 @@ namespace TMVA {

DataSetManager* fDataSetManager; // DSMTEST


DataInputHandler* fDataInputHandler;//->

std::vector<TMVA::VariableTransformBase*> fDefaultTrfs; // list of transformations on default DataSet
Expand All @@ -199,7 +199,7 @@ namespace TMVA {
Bool_t fVerbose; // verbose mode

// flag determining the way training and test data are assigned to DataLoader
enum DataAssignType { kUndefined = 0,
enum DataAssignType { kUndefined = 0,
kAssignTrees,
kAssignEvents };
DataAssignType fDataAssignType; // flags for data assigning
Expand All @@ -216,7 +216,7 @@ namespace TMVA {
Int_t fATreeType = 0; // type of event (=classIndex)
Float_t fATreeWeight = 0.0; // weight of the event
std::vector<Float_t> fATreeEvent; // event variables

Types::EAnalysisType fAnalysisType; // the training type

Bool_t fMakeFoldDataSet; // flag telling if the DataSet folds have been done
Expand All @@ -226,7 +226,8 @@ namespace TMVA {
ClassDef(DataLoader,3);
};
void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
std::vector<std::shared_ptr<TFile>> DataLoaderCopyMP(TMVA::DataLoader *des, TMVA::DataLoader *src);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be public? Would it be better to use private with Envelope as a friend? I think the intention is that only the Envelope should be able to use it. I would then propose that we enforce this. We can always open the interface up later, but not the other way around.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mammadhajili could you answer the comments of @ashlaban please?

void DataLoaderCopyMPCloseFiles(std::vector<std::shared_ptr<TFile>> files);
} // namespace TMVA

#endif

10 changes: 5 additions & 5 deletions tmva/tmva/src/Config.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ TMVA::Config& TMVA::gConfig() { return TMVA::Config::Instance(); }
/// constructor - set defaults

TMVA::Config::Config() :
fUseColoredConsole ( kTRUE ),
fSilent ( kFALSE ),
fUseColoredConsole(kTRUE),
fSilent(kFALSE),
fWriteOptionsReference( kFALSE ),
fDrawProgressBar ( kFALSE ),
fLogger ( new MsgLogger("Config") )
fDrawProgressBar(kFALSE),
fNWorkers(1),
fLogger(new MsgLogger("Config"))
{
// plotting
fVariablePlotting.fTimesRMS = 8.0;
Expand Down Expand Up @@ -113,4 +114,3 @@ TMVA::Config& TMVA::Config::Instance()
return fgConfigPtr ? *fgConfigPtr :*(fgConfigPtr = new Config());
#endif
}

66 changes: 45 additions & 21 deletions tmva/tmva/src/CrossValidation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "TMVA/ResultsClassification.h"
#include "TMVA/tmvaglob.h"
#include "TMVA/Types.h"
#include "ROOT/TProcessExecutor.hxx"

#include "TSystem.h"
#include "TAxis.h"
Expand All @@ -20,6 +21,7 @@

#include <iostream>
#include <memory>
using namespace std;

/*! \class TMVA::CrossValidationResult
\ingroup TMVA
Expand Down Expand Up @@ -126,16 +128,17 @@ void TMVA::CrossValidation::Evaluate()
fFoldStatus=kTRUE;
}

// Process K folds
for(UInt_t i=0; i<fNumFolds; ++i){
Log() << kDEBUG << "Fold (" << methodTitle << "): " << i << Endl;
auto workItem = [&](UInt_t workerID) {

Log() << kDEBUG << "Fold (" << methodTitle << "): " << workerID << Endl;
// Get specific fold of dataset and setup method
TString foldTitle = methodTitle;
foldTitle += "_fold";
foldTitle += i+1;

fDataLoader->PrepareFoldDataSet(i, TMVA::Types::kTesting);
MethodBase* smethod = fClassifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);
foldTitle += workerID + 1;
auto classifier = std::unique_ptr<Factory>(new TMVA::Factory(
"CrossValidation","!V:!ROC:Silent:!ModelPersistence:!Color:!DrawProgressBar:AnalysisType=Classification"));
fDataLoader->PrepareFoldDataSet(workerID, TMVA::Types::kTesting);
MethodBase *smethod = classifier->BookMethod(fDataLoader.get(), methodName, methodTitle, methodOptions);

// Train method
Event::SetIsTraining(kTRUE);
Expand All @@ -147,10 +150,10 @@ void TMVA::CrossValidation::Evaluate()
smethod->TestClassification();

// Store results
fResults.fROCs[i] = fClassifier->GetROCIntegral(fDataLoader->GetName(),methodTitle);
auto res = classifier->GetROCIntegral(fDataLoader->GetName(), methodTitle);

TGraph* gr = fClassifier->GetROCCurve(fDataLoader->GetName(), methodTitle, true);
gr->SetLineColor(i+1);
TGraph* gr = classifier->GetROCCurve(fDataLoader->GetName(), methodTitle, true);
gr->SetLineColor(workerID + 1);
gr->SetLineWidth(2);
gr->SetTitle(foldTitle.Data());
fResults.fROCCurves->Add(gr);
Expand All @@ -159,24 +162,45 @@ void TMVA::CrossValidation::Evaluate()
fResults.fSeps.push_back(smethod->GetSeparation());

Double_t err;
fResults.fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01",Types::kTesting, err));
fResults.fEff10s.push_back(smethod->GetEfficiency("Efficiency:0.10",Types::kTesting,err));
fResults.fEff30s.push_back(smethod->GetEfficiency("Efficiency:0.30",Types::kTesting,err));
fResults.fEffAreas.push_back(smethod->GetEfficiency("" ,Types::kTesting,err));
fResults.fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
fResults.fEff10s.push_back(smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
fResults.fEff30s.push_back(smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
fResults.fEffAreas.push_back(smethod->GetEfficiency("" , Types::kTesting, err));
fResults.fTrainEff01s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.01"));
fResults.fTrainEff10s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.10"));
fResults.fTrainEff30s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.30"));

// Clean-up for this fold
smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
smethod->Data()->DeleteResults(smethod->GetMethodName(), Types::kTraining, Types::kClassification);
fClassifier->DeleteAllMethods();
fClassifier->fMethodsMap.clear();
}

TMVA::gConfig().SetSilent(kFALSE);
Log() << kINFO << "Evaluation done." << Endl;
TMVA::gConfig().SetSilent(kTRUE);
classifier->DeleteAllMethods();
classifier->fMethodsMap.clear();

return make_pair(res, workerID);
};
vector<pair<double, UInt_t>> res;

auto nWorkers = TMVA::gConfig().NWorkers();

if(nWorkers > 1) {
ROOT::TProcessExecutor workers(nWorkers);
res = workers.Map(workItem, ROOT::TSeqI(fNumFolds));
}

else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the structure

if () {
   ...
} else {
   ...
}

for(UInt_t i = 0; i < fNumFolds; ++ i) {
auto res_pair = workItem(i);
res.push_back(res_pair);
}
}

for(auto res_pair: res) {
fResults.fROCs[res_pair.second] = res_pair.first;
}

TMVA::gConfig().SetSilent(kFALSE);
Log() << kINFO << "Evaluation done." << Endl;
TMVA::gConfig().SetSilent(kTRUE);
}

const TMVA::CrossValidationResult& TMVA::CrossValidation::GetResults() const {
Expand Down
Loading