9 changes: 7 additions & 2 deletions tmva/tmva/inc/TMVA/Config.h
@@ -75,6 +75,9 @@ namespace TMVA {
void SetDrawProgressBar( Bool_t d ) { fDrawProgressBar = d; }
UInt_t GetNCpu() { return fNCpu; }

UInt_t GetNumWorkers() const { return fNWorkers; }
void SetNumWorkers(UInt_t n) { fNWorkers = n; }

#ifdef R__USE_IMT
ROOT::TThreadExecutor &GetThreadExecutor() { return fPool; }
#endif
@@ -127,15 +130,17 @@
private:

#if __cplusplus > 199711L
std::atomic<Bool_t> fDrawProgressBar; // draw progress bar to indicate training evolution
std::atomic<UInt_t> fNWorkers; // Default number of workers for multi-process jobs
std::atomic<Bool_t> fUseColoredConsole; // coloured standard output
std::atomic<Bool_t> fSilent; // no output at all
std::atomic<Bool_t> fWriteOptionsReference; // if set true: Configurable objects write file with option reference
std::atomic<Bool_t> fDrawProgressBar; // draw progress bar to indicate training evolution
#else
Bool_t fDrawProgressBar; // draw progress bar to indicate training evolution
UInt_t fNWorkers; // Default number of workers for multi-process jobs
Bool_t fUseColoredConsole; // coloured standard output
Bool_t fSilent; // no output at all
Bool_t fWriteOptionsReference; // if set true: Configurable objects write file with option reference
Bool_t fDrawProgressBar; // draw progress bar to indicate training evolution
#endif
mutable MsgLogger* fLogger; // message logger
MsgLogger& Log() const { return *fLogger; }
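The two accessors above make the default worker count a process-wide setting, read by CrossValidation::Evaluate() when no per-job value is given (see CrossValidation.cxx below). A minimal usage sketch, not part of the diff, assuming only gConfig() and the accessors added here; the value 4 is an arbitrary illustration:

#include "TMVA/Config.h"

void ConfigureWorkers()
{
   // Request 4 worker processes for multi-process jobs (illustrative value;
   // the hard-coded default set in Config.cxx below is 1).
   TMVA::gConfig().SetNumWorkers(4);

   // Read the setting back; CrossValidation::Evaluate() does the same when
   // NumWorkerProcs is left at its default of 1.
   UInt_t nWorkers = TMVA::gConfig().GetNumWorkers();
   (void)nWorkers; // keep the sketch warning-free
}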
36 changes: 32 additions & 4 deletions tmva/tmva/inc/TMVA/CrossValidation.h
@@ -12,8 +12,9 @@
#ifndef ROOT_TMVA_CROSS_EVALUATION
#define ROOT_TMVA_CROSS_EVALUATION

#include "TString.h"
#include "TGraph.h"
#include "TMultiGraph.h"
#include "TString.h"

#include "TMVA/IMethod.h"
#include "TMVA/Configurable.h"
@@ -47,6 +48,29 @@ using EventTypes_t = std::vector<Bool_t>;
using EventOutputs_t = std::vector<Float_t>;
using EventOutputsMulticlass_t = std::vector<std::vector<Float_t>>;

class CrossValidationFoldResult {
public:
CrossValidationFoldResult() {} // For multi-proc serialisation
CrossValidationFoldResult(UInt_t iFold)
: fFold(iFold)
{}

UInt_t fFold;

Float_t fROCIntegral;
TGraph fROC;

Double_t fSig;
Double_t fSep;
Double_t fEff01;
Double_t fEff10;
Double_t fEff30;
Double_t fEffArea;
Double_t fTrainEff01;
Double_t fTrainEff10;
Double_t fTrainEff30;
};

// Used internally to keep per-fold aggregate statistics
// such as ROC curves, ROC integrals and efficiencies.
class CrossValidationResult {
@@ -67,7 +91,7 @@ class CrossValidationResult {
std::vector<Double_t> fTrainEff30s;

public:
CrossValidationResult();
CrossValidationResult(UInt_t numFolds);
CrossValidationResult(const CrossValidationResult &);
~CrossValidationResult() { fROCCurves = nullptr; }

@@ -88,6 +112,9 @@ class CrossValidationResult {
std::vector<Double_t> GetTrainEff01Values() const { return fTrainEff01s; }
std::vector<Double_t> GetTrainEff10Values() const { return fTrainEff10s; }
std::vector<Double_t> GetTrainEff30Values() const { return fTrainEff30s; }

private:
void Fill(CrossValidationFoldResult const & fr);
};

class CrossValidation : public Envelope {
@@ -113,8 +140,7 @@ class CrossValidation : public Envelope {
void Evaluate();

private:
void ProcessFold(UInt_t iFold, UInt_t iMethod);
void MergeFolds();
CrossValidationFoldResult ProcessFold(UInt_t iFold, UInt_t iMethod);

Types::EAnalysisType fAnalysisType;
TString fAnalysisTypeStr;
@@ -125,6 +151,8 @@ class CrossValidation : public Envelope {
Bool_t fFoldStatus; //! If true: dataset is prepared
TString fJobName;
UInt_t fNumFolds; //! Number of folds to prepare
UInt_t fNumWorkerProcs; //! Number of processes to use for fold evaluation.
//!(Default, no parallel evaluation)
TString fOutputFactoryOptions;
TString fOutputEnsembling; //! How to combine output of individual folds
TFile *fOutputFile;
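A usage sketch for the new NumWorkerProcs option (declared in CrossValidation.cxx below). Not part of the diff: the four-argument constructor and the BDT booking are assumed from the existing TMVA API; only NumFolds and NumWorkerProcs come from this change, the rest is illustrative.

#include "TFile.h"
#include "TMVA/CrossValidation.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Types.h"

void RunParallelCV(TMVA::DataLoader *dataloader, TFile *outputFile)
{
   // NumWorkerProcs=4 asks for four fold-evaluation processes; 1 (the default)
   // keeps the sequential behaviour, 0 lets TMVA choose from the CPU count.
   TMVA::CrossValidation cv{"CVJob", dataloader, outputFile,
                            "NumFolds=5:NumWorkerProcs=4"};

   cv.BookMethod(TMVA::Types::kBDT, "BDT", "NTrees=100"); // illustrative method
   cv.Evaluate(); // dispatches folds to worker processes, then aggregates
}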
2 changes: 2 additions & 0 deletions tmva/tmva/inc/TMVA/DataSet.h
@@ -117,6 +117,8 @@ namespace TMVA {
void DeleteResults ( const TString &,
Types::ETreeType type,
Types::EAnalysisType analysistype );
void DeleteAllResults(Types::ETreeType type,
Types::EAnalysisType analysistype);

void SetVerbose( Bool_t ) {}

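A small sketch of the per-fold cleanup this new overload enables, mirroring the two calls ProcessFold now makes after each fold (see CrossValidation.cxx below). The DataSet pointer is assumed to come from the evaluated method, as in smethod->Data():

#include "TMVA/DataSet.h"
#include "TMVA/Types.h"

// Drop every result container for the given analysis type, on both the
// training and the test tree, instead of deleting them one name at a time.
void CleanupFoldResults(TMVA::DataSet *dataset, TMVA::Types::EAnalysisType analysisType)
{
   dataset->DeleteAllResults(TMVA::Types::kTraining, analysisType);
   dataset->DeleteAllResults(TMVA::Types::kTesting, analysisType);
}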
5 changes: 3 additions & 2 deletions tmva/tmva/src/Config.cxx
@@ -54,11 +54,12 @@ TMVA::Config& TMVA::gConfig() { return TMVA::Config::Instance(); }
/// constructor - set defaults

TMVA::Config::Config() :
fDrawProgressBar ( kFALSE ),
fNWorkers (1),
fUseColoredConsole ( kTRUE ),
fSilent ( kFALSE ),
fWriteOptionsReference( kFALSE ),
fDrawProgressBar ( kFALSE ),
fLogger ( new MsgLogger("Config") )
fLogger (new MsgLogger("Config"))
{
// plotting
fVariablePlotting.fTimesRMS = 8.0;
113 changes: 94 additions & 19 deletions tmva/tmva/src/CrossValidation.cxx
@@ -36,15 +36,54 @@
#include <memory>

//_______________________________________________________________________
TMVA::CrossValidationResult::CrossValidationResult():fROCCurves(new TMultiGraph())
TMVA::CrossValidationResult::CrossValidationResult(UInt_t numFolds)
:fROCCurves(new TMultiGraph())
{
fSigs.resize(numFolds);
fSeps.resize(numFolds);
fEff01s.resize(numFolds);
fEff10s.resize(numFolds);
fEff30s.resize(numFolds);
fEffAreas.resize(numFolds);
fTrainEff01s.resize(numFolds);
fTrainEff10s.resize(numFolds);
fTrainEff30s.resize(numFolds);
}

//_______________________________________________________________________
TMVA::CrossValidationResult::CrossValidationResult(const CrossValidationResult &obj)
{
fROCs=obj.fROCs;
fROCCurves = obj.fROCCurves;

fSigs = obj.fSigs;
fSeps = obj.fSeps;
fEff01s = obj.fEff01s;
fEff10s = obj.fEff10s;
fEff30s = obj.fEff30s;
fEffAreas = obj.fEffAreas;
fTrainEff01s = obj.fTrainEff01s;
fTrainEff10s = obj.fTrainEff10s;
fTrainEff30s = obj.fTrainEff30s;
}

//_______________________________________________________________________
void TMVA::CrossValidationResult::Fill(CrossValidationFoldResult const & fr)
{
UInt_t iFold = fr.fFold;

fROCs[iFold] = fr.fROCIntegral;
fROCCurves->Add(static_cast<TGraph *>(fr.fROC.Clone()));

fSigs[iFold] = fr.fSig;
fSeps[iFold] = fr.fSep;
fEff01s[iFold] = fr.fEff01;
fEff10s[iFold] = fr.fEff10;
fEff30s[iFold] = fr.fEff30;
fEffAreas[iFold] = fr.fEffArea;
fTrainEff01s[iFold] = fr.fTrainEff01;
fTrainEff10s[iFold] = fr.fTrainEff10;
fTrainEff30s[iFold] = fr.fTrainEff30;
}

//_______________________________________________________________________
@@ -150,6 +189,7 @@ TMVA::CrossValidation::CrossValidation(TString jobName, TMVA::DataLoader *datalo
fFoldStatus(kFALSE),
fJobName(jobName),
fNumFolds(2),
fNumWorkerProcs(1),
fOutputFactoryOptions(""),
fOutputFile(outputFile),
fSilent(kFALSE),
@@ -212,6 +252,11 @@ void TMVA::CrossValidation::InitOptions()
// Options specific to CE
DeclareOptionRef(fSplitExprString, "SplitExpr", "The expression used to assign events to folds");
DeclareOptionRef(fNumFolds, "NumFolds", "Number of folds to generate");
DeclareOptionRef(fNumWorkerProcs, "NumWorkerProcs",
"Determines how many processes to use for evaluation. 1 means no"
" parallelisation. 2 means use 2 processes. 0 means figure out the"
" number automatically based on the number of cpus available. Default"
" 1.");

DeclareOptionRef(fFoldFileOutput, "FoldFileOutput",
"If given a TMVA output file will be generated for each fold. Filename will be the same as "
@@ -342,7 +387,7 @@ void TMVA::CrossValidation::SetSplitExpr(TString splitExpr)
/// @param iFold fold to evaluate
///

void TMVA::CrossValidation::ProcessFold(UInt_t iFold, UInt_t iMethod)
TMVA::CrossValidationFoldResult TMVA::CrossValidation::ProcessFold(UInt_t iFold, UInt_t iMethod)
{
TString methodName = fMethods[iMethod].GetValue<TString>("MethodName");
TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
@@ -376,28 +421,30 @@ void TMVA::CrossValidation::ProcessFold(UInt_t iFold, UInt_t iMethod)
fFoldFactory->TestAllMethods();
fFoldFactory->EvaluateAllMethods();

TMVA::CrossValidationFoldResult result(iFold);

// Results for aggregation (ROC integral, efficiencies etc.)
if (fAnalysisType == Types::kClassification or fAnalysisType == Types::kMulticlass) {
fResults[iMethod].fROCs[iFold] = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);
result.fROCIntegral = fFoldFactory->GetROCIntegral(fDataLoader->GetName(), foldTitle);

TGraph *gr = fFoldFactory->GetROCCurve(fDataLoader->GetName(), foldTitle, true);
gr->SetLineColor(iFold + 1);
gr->SetLineWidth(2);
gr->SetTitle(foldTitle.Data());
fResults[iMethod].fROCCurves->Add(gr);
result.fROC = *gr;

fResults[iMethod].fSigs.push_back(smethod->GetSignificance());
fResults[iMethod].fSeps.push_back(smethod->GetSeparation());
result.fSig = smethod->GetSignificance();
result.fSep = smethod->GetSeparation();

if (fAnalysisType == Types::kClassification) {
Double_t err;
fResults[iMethod].fEff01s.push_back(smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err));
fResults[iMethod].fEff10s.push_back(smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err));
fResults[iMethod].fEff30s.push_back(smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err));
fResults[iMethod].fEffAreas.push_back(smethod->GetEfficiency("", Types::kTesting, err));
fResults[iMethod].fTrainEff01s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.01"));
fResults[iMethod].fTrainEff10s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.10"));
fResults[iMethod].fTrainEff30s.push_back(smethod->GetTrainingEfficiency("Efficiency:0.30"));
result.fEff01 = smethod->GetEfficiency("Efficiency:0.01", Types::kTesting, err);
result.fEff10 = smethod->GetEfficiency("Efficiency:0.10", Types::kTesting, err);
result.fEff30 = smethod->GetEfficiency("Efficiency:0.30", Types::kTesting, err);
result.fEffArea = smethod->GetEfficiency("", Types::kTesting, err);
result.fTrainEff01 = smethod->GetTrainingEfficiency("Efficiency:0.01");
result.fTrainEff10 = smethod->GetTrainingEfficiency("Efficiency:0.10");
result.fTrainEff30 = smethod->GetTrainingEfficiency("Efficiency:0.30");
} else if (fAnalysisType == Types::kMulticlass) {
// Nothing here for now
}
@@ -410,12 +457,14 @@ void TMVA::CrossValidation::ProcessFold(UInt_t iFold, UInt_t iMethod)

// Clean-up for this fold
{
smethod->Data()->DeleteResults(foldTitle, Types::kTraining, smethod->GetAnalysisType());
smethod->Data()->DeleteResults(foldTitle, Types::kTesting, smethod->GetAnalysisType());
smethod->Data()->DeleteAllResults(Types::kTraining, smethod->GetAnalysisType());
smethod->Data()->DeleteAllResults(Types::kTesting, smethod->GetAnalysisType());
}

fFoldFactory->DeleteAllMethods();
fFoldFactory->fMethodsMap.clear();

return result;
}

////////////////////////////////////////////////////////////////////////////////
@@ -431,8 +480,9 @@ void TMVA::CrossValidation::Evaluate()
fFoldStatus = kTRUE;
}

fResults.resize(fMethods.size());
fResults.reserve(fMethods.size());
for (UInt_t iMethod = 0; iMethod < fMethods.size(); iMethod++) {
CrossValidationResult result{fNumFolds};

TString methodTypeName = fMethods[iMethod].GetValue<TString>("MethodName");
TString methodTitle = fMethods[iMethod].GetValue<TString>("MethodTitle");
@@ -445,10 +495,33 @@ void TMVA::CrossValidation::Evaluate()
Log() << kINFO << "Evaluate method: " << methodTitle << Endl;

// Process K folds
for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
ProcessFold(iFold, iMethod);
auto nWorkers = fNumWorkerProcs;
if (nWorkers == 1) {
// Fall back to global config
nWorkers = TMVA::gConfig().GetNumWorkers();
}
if (nWorkers == 1) {
for (UInt_t iFold = 0; iFold < fNumFolds; ++iFold) {
auto fold_result = ProcessFold(iFold, iMethod);
result.Fill(fold_result);
}
} else {
ROOT::TProcessExecutor workers(nWorkers);
std::vector<CrossValidationFoldResult> result_vector;

auto workItem = [this, iMethod](UInt_t iFold) {
return ProcessFold(iFold, iMethod);
};

result_vector = workers.Map(workItem, ROOT::TSeqI(fNumFolds));

for (auto && fold_result : result_vector) {
result.Fill(fold_result);
}
}

fResults.push_back(result);

// Serialise the cross evaluated method
TString options =
Form("SplitExpr=%s:NumFolds=%i"
@@ -478,7 +551,9 @@ void TMVA::CrossValidation::Evaluate()
IMethod *method_interface = fFactory->GetMethod(fDataLoader.get()->GetName(), methodTitle);
MethodCrossValidation *method = dynamic_cast<MethodCrossValidation *>(method_interface);

fFactory->WriteDataInformation(method->fDataSetInfo);
if (fOutputFile) {
fFactory->WriteDataInformation(method->fDataSetInfo);
}

Event::SetIsTraining(kTRUE);
method->TrainMethod();
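For reference, a standalone sketch (not part of the diff) of the fork-and-map pattern the new Evaluate() uses. ROOT::TProcessExecutor forks the requested number of worker processes, runs the callable once per element of the sequence, and returns the results serialised back from the workers, which is why CrossValidationFoldResult needs a default constructor. Here a plain integer stands in for the fold result; in Evaluate() the callable captures the CrossValidation object and calls ProcessFold.

#include <ROOT/TProcessExecutor.hxx>
#include <ROOT/TSeq.hxx>
#include <iostream>
#include <vector>

int main()
{
   const unsigned nWorkers = 4; // illustrative; Evaluate() takes this from NumWorkerProcs or gConfig()
   const unsigned nFolds = 5;   // illustrative

   ROOT::TProcessExecutor workers(nWorkers);

   // Each task runs in a forked child; its return value is shipped back to the
   // parent process, so the result type must be serialisable by ROOT.
   auto workItem = [](unsigned iFold) { return iFold * iFold; };

   std::vector<unsigned> results = workers.Map(workItem, ROOT::TSeqU(nFolds));

   for (auto r : results)
      std::cout << "fold result: " << r << '\n';
   return 0;
}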