Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 274 additions & 0 deletions tmva/tmva/src/Factory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#include "TText.h"
#include "TLegend.h"
#include "TGraph.h"
#include "TMultiGraph.h"
#include "TSpline.h"
#include "TStyle.h"
#include "TMatrixF.h"
#include "TMatrixDSym.h"
Expand Down Expand Up @@ -98,6 +100,7 @@

#include <TCanvas.h>


// NOTE(review): presumably the minimum number of training events required
// before a method is trained; its use is outside the visible part of the
// file -- confirm.
const Int_t MinNoTrainingEvents = 10;
//const Int_t MinNoTestEvents = 1;
// Definition of the static Factory member that points at the output file;
// it starts out null. The cross-validation routines in this file cd() into it
// and delete the per-fold entries they created from it.
TFile* TMVA::Factory::fgTargetFile = 0;
Expand Down Expand Up @@ -2078,6 +2081,7 @@ float TMVA::Factory::CrossValidate(DataLoader * loader, Types::EMVA theMethod, T
EvaluateAllMethods();

ROCs.push_back(GetROCIntegral(seedloader->GetName(), methodTitle));
// here is where i have to GET the ROCCurve

TMVA::MethodBase * smethod = dynamic_cast<TMVA::MethodBase*>(fMethodsMap[seedloader->GetName()][0][0]);
TMVA::ResultsClassification * sresults = (TMVA::ResultsClassification*)smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification);
Expand Down Expand Up @@ -2116,5 +2120,275 @@ float TMVA::Factory::CrossValidate(DataLoader * loader, Types::EMVA theMethod, T
}

return sumFOM/(double)NumFolds;


}


////////////////////////////////////////////////////////////////////////////////
/// Run a k-fold cross validation and print a histogram of the per-fold ROC
/// integrals (AUC) to "AUCHisto.pdf".
///
/// \param[in] loader        data loader that provides the input variables and
///                          the k-fold data set
/// \param[in] theMethod     MVA method type booked for every fold
/// \param[in] methodTitle   title under which the method is booked; also the
///                          prefix of the per-fold loader names
/// \param[in] theOption     option string handed to BookMethod
/// \param[in] optParams     if true, tune the method parameters on every fold
///                          and print the cross-validated performance of each
///                          tuned parameter set. No per-fold AUC is collected
///                          in this mode, so no histogram is written.
/// \param[in] NumFolds      number of folds; must be positive
/// \param[in] remakeDataSet if true, (re)build the k-fold data set first

void TMVA::Factory::CrossValidateAUCHisto(DataLoader * loader, Types::EMVA theMethod, TString methodTitle, const char *theOption, bool optParams, int NumFolds, bool remakeDataSet)
{
   // Without at least one fold there is nothing to train and nothing to plot.
   if(NumFolds <= 0){
      std::cout << "CrossValidateAUCHisto: NumFolds must be positive, got " << NumFolds << std::endl;
      return;
   }

   if(remakeDataSet){
      loader->MakeKFoldDataSet(NumFolds);
   }

   const int nVars = loader->DefaultDataSetInfo().GetNVariables();
   const std::vector<TString> varNames = loader->DefaultDataSetInfo().GetListOfVariables();

   if(optParams){

      // One tuned parameter set per fold.
      std::vector<std::map<TString,Double_t> > foldParameters;

      for(Int_t i=0; i<NumFolds; ++i){
         Event::SetIsTraining(kTRUE);
         TString optTitle = methodTitle;
         optTitle += "_opt";
         optTitle += i;

         loader->PrepareTrainingAndTestTree(i, TMVA::Types::kTraining);

         // NOTE(review): seedloader is never deleted in this function. Whether
         // the booked method keeps referring to it is not visible here, so the
         // original lifetime is kept -- confirm ownership.
         TMVA::DataLoader * seedloader = new TMVA::DataLoader(optTitle);

         for(int index = 0; index<nVars; ++index){
            seedloader->AddVariable(varNames.at(index), 'F');
         }

         VIDataLoaderCopy(seedloader,loader);

         MethodBase* mva = BookMethod(seedloader, theMethod, methodTitle, theOption);

         if(mva){
            foldParameters.push_back(mva->OptimizeTuningParameters("ROCIntegral","Minuit"));
         }
         else{
            std::cout << "CrossValidateAUCHisto: could not book method for " << optTitle << ", fold skipped" << std::endl;
         }

         this->DeleteAllMethods();

         fMethodsMap.clear();
      }

      std::vector<float> parameterPerformance;

      for(UInt_t t=0; t<foldParameters.size(); ++t){
         // Append every tuned parameter as ":name=value". Putting the separator
         // in front of each entry avoids a dangling ':' when the map is empty.
         TString optionsString = theOption;
         const std::map<TString,Double_t> & tuned = foldParameters.at(t);
         for(std::map<TString,Double_t>::const_iterator it=tuned.begin(); it!=tuned.end(); ++it){
            optionsString += ":";
            optionsString += it->first;
            optionsString += "=";
            optionsString += it->second;
         }
         // NOTE(review): the trailing "false, false" is kept from the original.
         // If CrossValidate's parameters mirror this function's (optParams,
         // NumFolds, remakeDataSet), the second "false" is taken as
         // NumFolds == 0 -- confirm against the declaration.
         parameterPerformance.push_back(CrossValidate(loader, theMethod, methodTitle, optionsString, false, false));
      }

      for(UInt_t t=0; t<parameterPerformance.size(); ++t){
         std::cout << "Parameters " << t+1 << " performance: " << parameterPerformance.at(t) << std::endl;
      }

      // No per-fold AUC values exist in this mode; the histogram is skipped
      // rather than being filled from uninitialised storage.
      std::cout << "CrossValidateAUCHisto: no per-fold AUC collected with optParams set, histogram not written" << std::endl;
      return;
   }

   // AUC of every fold, in fold order.
   std::vector<Double_t> foldAUC;
   foldAUC.reserve(NumFolds);

   for(Int_t j=0; j<NumFolds; ++j){
      TString foldTitle = methodTitle;
      foldTitle += "_fold";
      foldTitle += j+1;

      loader->PrepareTrainingAndTestTree(j, TMVA::Types::kTesting);

      // NOTE(review): never deleted, see the note in the tuning loop above.
      TMVA::DataLoader * seedloader = new TMVA::DataLoader(foldTitle);

      for(int index = 0; index<nVars; ++index){
         seedloader->AddVariable(varNames.at(index), 'F');
      }

      VIDataLoaderCopy(seedloader,loader);

      BookMethod(seedloader, theMethod, methodTitle, theOption);

      TrainAllMethods();
      TestAllMethods();
      EvaluateAllMethods();

      foldAUC.push_back(GetROCIntegral(seedloader->GetName(), methodTitle));

      // Drop the test results of this fold before the next one is booked.
      TMVA::MethodBase * smethod = dynamic_cast<TMVA::MethodBase*>(fMethodsMap[seedloader->GetName()][0][0]);
      if(smethod){
         TMVA::ResultsClassification * sresults = static_cast<TMVA::ResultsClassification*>(smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification));
         if(sresults){
            // NOTE(review): the results object is deleted by hand, as in the
            // original; who owns it is not visible here -- confirm.
            sresults->Clear();
            sresults->Delete();
            delete sresults;
         }
      }

      // Remove what this fold wrote to the target file and to disk.
      fgTargetFile->cd();
      fgTargetFile->Delete(seedloader->GetName());
      fgTargetFile->Delete(Form("%s;1",seedloader->GetName()));
      fgTargetFile->Flush();
      // NOTE(review): the loader name (derived from methodTitle) goes into a
      // shell command unquoted; a title with spaces or shell metacharacters
      // would change what gets removed.
      gSystem->Exec(Form("rm -rf %s", seedloader->GetName()));

      this->DeleteAllMethods();

      fMethodsMap.clear();
   }

   // Axis range of the histogram: smallest and largest AUC plus a margin.
   // foldAUC holds NumFolds (> 0) entries at this point.
   Double_t min = foldAUC.at(0);
   Double_t max = foldAUC.at(0);
   for(UInt_t i = 1; i < foldAUC.size(); ++i){
      if(foldAUC.at(i) > max){ max = foldAUC.at(i); }
      if(foldAUC.at(i) < min){ min = foldAUC.at(i); }
   }

   TCanvas c ("c", "Histogram");
   TH1D histo("AUC Distribution", "Histo", 100, min - 0.01, max + 0.01);
   for(UInt_t i = 0; i < foldAUC.size(); ++i){
      // NOTE(review): every entry is weighted with its 1-based fold number, as
      // in the original; a plain distribution would use unit weights -- confirm
      // the intent.
      histo.Fill(foldAUC.at(i), i + 1.00);
   }
   histo.SetFillColor(90);
   histo.SetTitle(" AUC Distribution ");
   histo.Draw("HIST");
   c.Print("AUCHisto.pdf");

}
////////////////////////////////////////////////////////////////////////////////
/// Run a k-fold cross validation and print the ROC curves of all folds,
/// overlaid in one multigraph, to "ROC.pdf".
///
/// \param[in] loader        data loader that provides the input variables and
///                          the k-fold data set
/// \param[in] theMethod     MVA method type booked for every fold
/// \param[in] methodTitle   title under which the method is booked and under
///                          which its ROC curve is looked up; also the prefix
///                          of the per-fold loader names
/// \param[in] theOption     option string handed to BookMethod
/// \param[in] optParams     if true, tune the method parameters on every fold
///                          and print the cross-validated performance of each
///                          tuned parameter set. No ROC curve is collected in
///                          this mode, so no plot is written.
/// \param[in] NumFolds      number of folds; must be positive
/// \param[in] remakeDataSet if true, (re)build the k-fold data set first

void TMVA::Factory::CrossValidatePlotROC(DataLoader * loader, Types::EMVA theMethod, TString methodTitle, const char *theOption, bool optParams, int NumFolds, bool remakeDataSet)
{
   // Without at least one fold there is nothing to train and nothing to plot.
   if(NumFolds <= 0){
      std::cout << "CrossValidatePlotROC: NumFolds must be positive, got " << NumFolds << std::endl;
      return;
   }

   if(remakeDataSet){
      loader->MakeKFoldDataSet(NumFolds);
   }

   const int nVars = loader->DefaultDataSetInfo().GetNVariables();
   const std::vector<TString> varNames = loader->DefaultDataSetInfo().GetListOfVariables();

   if(optParams){

      // One tuned parameter set per fold.
      std::vector<std::map<TString,Double_t> > foldParameters;

      for(Int_t i=0; i<NumFolds; ++i){
         Event::SetIsTraining(kTRUE);
         TString optTitle = methodTitle;
         optTitle += "_opt";
         optTitle += i;

         loader->PrepareTrainingAndTestTree(i, TMVA::Types::kTraining);

         // NOTE(review): seedloader is never deleted in this function. Whether
         // the booked method keeps referring to it is not visible here, so the
         // original lifetime is kept -- confirm ownership.
         TMVA::DataLoader * seedloader = new TMVA::DataLoader(optTitle);

         for(int index = 0; index<nVars; ++index){
            seedloader->AddVariable(varNames.at(index), 'F');
         }

         VIDataLoaderCopy(seedloader,loader);

         MethodBase* mva = BookMethod(seedloader, theMethod, methodTitle, theOption);

         if(mva){
            foldParameters.push_back(mva->OptimizeTuningParameters("ROCIntegral","Minuit"));
         }
         else{
            std::cout << "CrossValidatePlotROC: could not book method for " << optTitle << ", fold skipped" << std::endl;
         }

         this->DeleteAllMethods();

         fMethodsMap.clear();
      }

      for(UInt_t t=0; t<foldParameters.size(); ++t){
         // Append every tuned parameter as ":name=value". Putting the separator
         // in front of each entry avoids a dangling ':' when the map is empty.
         TString optionsString = theOption;
         const std::map<TString,Double_t> & tuned = foldParameters.at(t);
         for(std::map<TString,Double_t>::const_iterator it=tuned.begin(); it!=tuned.end(); ++it){
            optionsString += ":";
            optionsString += it->first;
            optionsString += "=";
            optionsString += it->second;
         }
         // NOTE(review): the trailing "false, false" is kept from the original.
         // If CrossValidate's parameters mirror this function's (optParams,
         // NumFolds, remakeDataSet), the second "false" is taken as
         // NumFolds == 0 -- confirm against the declaration.
         const float performance = CrossValidate(loader, theMethod, methodTitle, optionsString, false, false);
         // Report the value (same wording as CrossValidateAUCHisto) instead of
         // computing it and throwing it away.
         std::cout << "Parameters " << t+1 << " performance: " << performance << std::endl;
      }

      // No ROC curve is produced in this mode, so there is nothing to draw.
      std::cout << "CrossValidatePlotROC: no ROC curves collected with optParams set, plot not written" << std::endl;
      return;
   }

   // ROC curve of every fold that produced one, in fold order.
   std::vector<TGraph*> rocCurves;
   rocCurves.reserve(NumFolds);

   for(Int_t j=0; j<NumFolds; ++j){
      TString foldTitle = methodTitle;
      foldTitle += "_fold";
      foldTitle += j+1;

      loader->PrepareTrainingAndTestTree(j, TMVA::Types::kTesting);

      // NOTE(review): never deleted, see the note in the tuning loop above.
      TMVA::DataLoader * seedloader = new TMVA::DataLoader(foldTitle);

      for(int index = 0; index<nVars; ++index){
         seedloader->AddVariable(varNames.at(index), 'F');
      }

      VIDataLoaderCopy(seedloader,loader);

      BookMethod(seedloader, theMethod, methodTitle, theOption);

      TrainAllMethods();
      TestAllMethods();
      EvaluateAllMethods();

      // Fetch the curve once, under the title the method was booked with, so
      // that the colour is set on the very graph that ends up in the plot.
      TGraph * rocCurve = GetROCCurve(seedloader->GetName(), methodTitle, true);
      if(rocCurve){
         // Cycle through colour indices 1..9; indices 0 and 10 are white in
         // ROOT's default colour table and would be invisible.
         rocCurve->SetLineColor(1 + j % 9);
         rocCurves.push_back(rocCurve);
      }
      else{
         std::cout << "CrossValidatePlotROC: no ROC curve available for " << foldTitle << std::endl;
      }

      // Drop the test results of this fold before the next one is booked.
      TMVA::MethodBase * smethod = dynamic_cast<TMVA::MethodBase*>(fMethodsMap[seedloader->GetName()][0][0]);
      if(smethod){
         TMVA::ResultsClassification * sresults = static_cast<TMVA::ResultsClassification*>(smethod->Data()->GetResults(smethod->GetMethodName(), Types::kTesting, Types::kClassification));
         if(sresults){
            // NOTE(review): the results object is deleted by hand, as in the
            // original; who owns it is not visible here -- confirm.
            sresults->Clear();
            sresults->Delete();
            delete sresults;
         }
      }

      // Remove what this fold wrote to the target file and to disk.
      fgTargetFile->cd();
      fgTargetFile->Delete(seedloader->GetName());
      fgTargetFile->Delete(Form("%s;1",seedloader->GetName()));
      fgTargetFile->Flush();
      // NOTE(review): the loader name (derived from methodTitle) goes into a
      // shell command unquoted; a title with spaces or shell metacharacters
      // would change what gets removed.
      gSystem->Exec(Form("rm -rf %s", seedloader->GetName()));

      this->DeleteAllMethods();

      fMethodsMap.clear();
   }

   if(rocCurves.empty()){
      std::cout << "CrossValidatePlotROC: no ROC curves to draw, plot not written" << std::endl;
      return;
   }

   // Canvas and multigraph live on the stack so that they are released when the
   // function returns; the canvas is created only now so that it is the
   // current pad when the multigraph is drawn.
   TCanvas c("c", "MultiGraph", 700, 500);
   TMultiGraph mg;
   for(UInt_t k=0; k<rocCurves.size(); ++k){
      mg.Add(rocCurves.at(k));
   }

   mg.Draw("al");
   TAxis * xAxis = mg.GetXaxis();
   if(xAxis){ xAxis->SetTitle(" Signal Efficiency "); }
   TAxis * yAxis = mg.GetYaxis();
   if(yAxis){ yAxis->SetTitle(" Background Rejection "); }
   c.Print("ROC.pdf");

}