Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add multiclass ROC curves for TMVAMulticlassGui
Introduces a new (for multiclass anyway) button in the gui that
when clicked displays one ROC curve per class. Each curve contains
the performance of all methods for that class.

Uses the new ROCCurve class to calculate the curves.
  • Loading branch information
ashlaban committed May 12, 2017
commit 7e14e1625d3ea3383d6b27ff24ad5e50c29a950c
1 change: 1 addition & 0 deletions tmva/tmva/inc/TMVA/ResultsMulticlass.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ namespace TMVA {
std::vector<Float_t>& GetAchievableEff(){return fAchievableEff;}
std::vector<Float_t>& GetAchievablePur(){return fAchievablePur;}
// histogramming
void CreateMulticlassPerformanceHistos(TString prefix);
void CreateMulticlassHistos( TString prefix, Int_t nbins, Int_t nbins_high);

Double_t EstimatorFunction( std::vector<Double_t> & );
Expand Down
298 changes: 152 additions & 146 deletions tmva/tmva/src/Factory.cxx

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tmva/tmva/src/MethodBase.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ void TMVA::MethodBase::AddMulticlassOutput(Types::ETreeType type)
TString histNamePrefix(GetTestvarName());
histNamePrefix += (type==Types::kTraining?"_Train":"_Test");
resMulticlass->CreateMulticlassHistos( histNamePrefix, fNbinsMVAoutput, fNbinsH );
resMulticlass->CreateMulticlassPerformanceHistos(histNamePrefix);
}

////////////////////////////////////////////////////////////////////////////////
Expand Down
61 changes: 61 additions & 0 deletions tmva/tmva/src/ResultsMulticlass.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ Class which takes the results of a multiclass classification
#include "TMVA/GeneticFitter.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Results.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

#include "TGraph.h"
#include "TH1F.h"

#include <limits>
Expand Down Expand Up @@ -176,6 +178,65 @@ std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targ
return result;
}

////////////////////////////////////////////////////////////////////////////////
/// Create performance graphs for this classifier a multiclass setting.
/// Requires that the method has already been evaluated (that a resultset
/// already exists.)
///
/// Currently uses the new way of calculating ROC Curves. If anything looks
/// fishy, please contact the ROOT TMVA team.
///

void TMVA::ResultsMulticlass::CreateMulticlassPerformanceHistos(TString prefix)
{
DataSet *ds = GetDataSet();
ds->SetCurrentType(GetTreeType());
const DataSetInfo *dsi = GetDataSetInfo();

UInt_t numClasses = dsi->GetNClasses();

std::vector<std::vector<Float_t>> *rawMvaRes = GetValueVector();

for (size_t iClass = 0; iClass < numClasses; ++iClass) {
// Format data
// TODO: Replace with calls to GetMvaValuesPerClass
std::vector<Float_t> mvaRes;
std::vector<Bool_t> mvaResTypes;
std::vector<Float_t> mvaResWeights;

// Vector transpose due to values being stored as
// [ [0, 1, 2], [0, 1, 2], ... ]
// in ResultsMulticlass::GetValueVector.
mvaRes.reserve(rawMvaRes->size());
for (auto item : *rawMvaRes) {
mvaRes.push_back(item[iClass]);
}

auto eventCollection = ds->GetEventCollection();
mvaResTypes.reserve(eventCollection.size());
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResTypes.push_back(ev->GetClass() == iClass);
mvaResWeights.push_back(ev->GetWeight());
}

// Get ROC Curve
ROCCurve *roc = new ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
TGraph *rocGraph = new TGraph(*(roc->GetROCCurve()));
delete roc;

// Style ROC Curve
TString className = dsi->GetClassInfo(iClass)->GetName();
TString name = Form("%s_rejBvsS_%s", prefix.Data(), className.Data());
TString title = Form("%s_%s", prefix.Data(), className.Data());
rocGraph->SetName(name);
rocGraph->SetTitle(title);

// Store ROC Curve
Store(rocGraph);
}
}

////////////////////////////////////////////////////////////////////////////////
/// this function fills the mva response histos for multiclass classification

Expand Down
2 changes: 1 addition & 1 deletion tmva/tmvagui/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ if(NOT CMAKE_PROJECT_NAME STREQUAL ROOT)
endif()

set(headers1 annconvergencetest.h deviations.h mvaeffs.h PlotFoams.h TMVAGui.h
BDTControlPlots.h correlationscatters.h efficiencies.h mvas.h probas.h
BDTControlPlots.h correlationscatters.h efficiencies.h efficienciesMulticlass.h mvas.h probas.h
BDT.h correlationscattersMultiClass.h likelihoodrefs.h mvasMulticlass.h regression_averagedevs.h TMVAMultiClassGui.h
BDT_Reg.h correlations.h mvaweights.h rulevisCorr.h TMVARegGui.h
BoostControlPlots.h correlationsMultiClass.h network.h rulevis.h variables.h
Expand Down
2 changes: 1 addition & 1 deletion tmva/tmvagui/Module.mk
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TMVAGUIDO := $(TMVAGUIDS:.cxx=.o)
TMVAGUIDH := $(TMVAGUIDS:.cxx=.h)

TMVAGUIH1 := annconvergencetest.h deviations.h mvaeffs.h PlotFoams.h TMVAGui.h\
BDTControlPlots.h correlationscatters.h efficiencies.h mvas.h probas.h \
BDTControlPlots.h correlationscatters.h efficiencies.h efficienciesMulticlass.h mvas.h probas.h \
BDT.h correlationscattersMultiClass.h likelihoodrefs.h mvasMulticlass.h regression_averagedevs.h TMVAMultiClassGui.h\
BDT_Reg.h correlations.h mvaweights.h rulevisCorr.h TMVARegGui.h\
BoostControlPlots.h correlationsMultiClass.h network.h rulevis.h variables.h\
Expand Down
3 changes: 2 additions & 1 deletion tmva/tmvagui/inc/LinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#pragma link C++ function TMVA::CorrGui;
#pragma link C++ function TMVA::CorrGuiMultiClass;
#pragma link C++ function TMVA::deviations;
#pragma link C++ function TMVA::efficiencies;
#pragma link C++ function TMVA::efficiencies;
#pragma link C++ function TMVA::efficienciesMulticlass;
#pragma link C++ function TMVA::likelihoodrefs;
#pragma link C++ function TMVA::MovieMaker;
#pragma link C++ defined_in "TMVA/mvaeffs.h";
Expand Down
25 changes: 25 additions & 0 deletions tmva/tmvagui/inc/TMVA/efficienciesMulticlass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef efficienciesMulticlass__HH
#define efficienciesMulticlass__HH

#include "tmvaglob.h"

class TCanvas;
class TDirectory;
class TFile;
class TGraph;
class TString;

namespace TMVA {

enum class EEfficiencyPlotType { kEffBvsEffS, kRejBvsEffS };

void efficienciesMulticlass(TString dataset, TString filename_input = "TMVAMulticlass.root",
EEfficiencyPlotType plotType = EEfficiencyPlotType::kRejBvsEffS,
Bool_t useTMVAStyle = kTRUE);

void plotEfficienciesMulticlass(EEfficiencyPlotType plotType = EEfficiencyPlotType::kRejBvsEffS,
TDirectory *BinDir = 0);

} // namespace TMVA

#endif
84 changes: 42 additions & 42 deletions tmva/tmvagui/src/TMVAMultiClassGui.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -189,63 +189,63 @@ void TMVA::TMVAMultiClassGui(const char* fName ,TString dataset)
buttonType, defaultRequiredClassifier );
/*
title = Form( "(%ic) Classifier Probability Distributions (test sample)", ic );
MultiClassActionButton( cbar,
MultiClassActionButton( cbar,
Form( "(%ic) Classifier Probability Distributions (test sample)", ic ),
Form( "TMVA::mvas(\"%s\",TMVA::kProbaType)", fName ),
"Plots the probability of each classifier for the test data (macro mvas(...,1))",
buttonType, defaultRequiredClassifier );

title =Form( "(%id) Classifier Rarity Distributions (test sample)", ic );
MultiClassActionButton( cbar,
MultiClassActionButton( cbar,
Form( "(%id) Classifier Rarity Distributions (test sample)", ic ),
Form( "TMVA::mvas(\"%s\",TMVA::kRarityType)", fName ),
"Plots the Rarity of each classifier for the test data (macro mvas(...,2)) - background distribution should be uniform",
buttonType, defaultRequiredClassifier );
"Plots the Rarity of each classifier for the test data (macro mvas(...,2)) - background distribution should be
uniform", buttonType, defaultRequiredClassifier );


title =Form( "(%ia) Classifier Cut Efficiencies", ++ic );
MultiClassActionButton( cbar,
MultiClassActionButton( cbar,
title,
Form( "TMVA::mvaeffs(\"%s\")", fName ),
"Plots signal and background efficiencies versus cut on classifier output (macro mvaeffs.cxx)",
buttonType, defaultRequiredClassifier );
*/

title = Form( "(%ib) Classifier Background Rejection vs Signal Efficiency (ROC curve)", ic );
MultiClassActionButton( cbar,
title,
Form( "TMVA::efficiencies(\"%s\")", fName ),
"Plots background rejection vs signal efficiencies (macro efficiencies.cxx) [\"ROC\" stands for \"Receiver Operation Characteristics\"]",
buttonType, defaultRequiredClassifier );


title = Form( "(%i) Parallel Coordinates (requires ROOT-version >= 5.17)", ++ic );
MultiClassActionButton( cbar,
title,
Form( "TMVA::paracoor(\"%s\")", fName ),
"Plots parallel coordinates for classifiers and input variables (macro paracoor.cxx, requires ROOT >= 5.17)",
buttonType, defaultRequiredClassifier );
title = Form("(%i) Classifier Background Rejection vs Signal Efficiency (ROC curve)", ++ic);
MultiClassActionButton(cbar, title, Form("TMVA::efficienciesMulticlass(\"%s\", \"%s\")", dataset.Data(), fName),
"Plots background rejection vs signal efficiencies (macro efficiencies.cxx) [\"ROC\" stands "
"for \"Receiver Operation Characteristics\"]",
buttonType, defaultRequiredClassifier);

// parallel coordinates only exist since ROOT 5.17
#if ROOT_VERSION_CODE < ROOT_VERSION(5,17,0)
TMVAMultiClassGui_inactiveButtons.push_back( title );
#endif


title =Form( "(%i) PDFs of Classifiers (requires \"CreateMVAPdfs\" option set)", ++ic );
MultiClassActionButton( cbar,
title,
Form( "TMVA::probas(\"%s\")", fName ),
"Plots the PDFs of the classifier output distributions for signal and background - if requested (macro probas.cxx)",
buttonType, defaultRequiredClassifier );
/*
title = Form( "(%i) Parallel Coordinates (requires ROOT-version >= 5.17)", ++ic );
MultiClassActionButton( cbar,
title,
Form( "TMVA::paracoor(\"%s\")", fName ),
"Plots parallel coordinates for classifiers and input variables (macro paracoor.cxx, requires ROOT >= 5.17)",
buttonType, defaultRequiredClassifier );

// parallel coordinates only exist since ROOT 5.17
#if ROOT_VERSION_CODE < ROOT_VERSION(5,17,0)
TMVAMultiClassGui_inactiveButtons.push_back( title );
#endif


title =Form( "(%i) PDFs of Classifiers (requires \"CreateMVAPdfs\" option set)", ++ic );
MultiClassActionButton( cbar,
title,
Form( "TMVA::probas(\"%s\")", fName ),
"Plots the PDFs of the classifier output distributions for signal and background - if requested (macro probas.cxx)",
buttonType, defaultRequiredClassifier );

title = Form( "(%i) Likelihood Reference Distributiuons", ++ic);
MultiClassActionButton( cbar,
title,
Form( "TMVA::likelihoodrefs(\"%s\")", fName ),
"Plots to verify the likelihood reference distributions (macro likelihoodrefs.cxx)",
buttonType, "Likelihood" );
*/

title = Form( "(%i) Likelihood Reference Distributiuons", ++ic);
MultiClassActionButton( cbar,
title,
Form( "TMVA::likelihoodrefs(\"%s\")", fName ),
"Plots to verify the likelihood reference distributions (macro likelihoodrefs.cxx)",
buttonType, "Likelihood" );
*/

title = Form( "(%ia) Network Architecture (MLP)", ++ic );
TString call = Form( "TMVA::network(\"%s\",\"%s\")",dataset.Data() , fName );
MultiClassActionButton( cbar,
Expand Down
Loading