Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tmva/tmva/inc/TMVA/MethodBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ namespace TMVA {
virtual Double_t GetTrainingEfficiency(const TString& );
virtual std::vector<Float_t> GetMulticlassEfficiency( std::vector<std::vector<Float_t> >& purity );
virtual std::vector<Float_t> GetMulticlassTrainingEfficiency(std::vector<std::vector<Float_t> >& purity );
virtual TMatrixD GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type);
virtual Double_t GetSignificance() const;
virtual Double_t GetROCIntegral(TH1D *histS, TH1D *histB) const;
virtual Double_t GetROCIntegral(PDF *pdfS=0, PDF *pdfB=0) const;
Expand Down
2 changes: 2 additions & 0 deletions tmva/tmva/inc/TMVA/ROCCurve.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class ROCCurve {

~ROCCurve();

Double_t GetEffSForEffB(Double_t effB, const UInt_t num_points = 41);

Double_t GetROCIntegral(const UInt_t points = 41);
TGraph *GetROCCurve(const UInt_t points = 100); // n divisions = #points -1

Expand Down
3 changes: 3 additions & 0 deletions tmva/tmva/inc/TMVA/ResultsMulticlass.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ namespace TMVA {
Float_t GetAchievablePur(UInt_t cls){return fAchievablePur.at(cls);}
std::vector<Float_t>& GetAchievableEff(){return fAchievableEff;}
std::vector<Float_t>& GetAchievablePur(){return fAchievablePur;}

TMatrixD GetConfusionMatrix(Double_t effB);

// histogramming
void CreateMulticlassPerformanceHistos(TString prefix);
void CreateMulticlassHistos( TString prefix, Int_t nbins, Int_t nbins_high);
Expand Down
219 changes: 193 additions & 26 deletions tmva/tmva/src/Factory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,11 @@ void TMVA::Factory::EvaluateAllMethods( void )
std::vector<std::vector<Float_t> > multiclass_testPur;
std::vector<std::vector<Float_t> > multiclass_trainPur;

// Multiclass confusion matrices.
std::vector<TMatrixD> multiclass_testConfusionEffB01;
std::vector<TMatrixD> multiclass_testConfusionEffB10;
std::vector<TMatrixD> multiclass_testConfusionEffB30;

std::vector<std::vector<Double_t> > biastrain(1); // "bias" of the regression on the training data
std::vector<std::vector<Double_t> > biastest(1); // "bias" of the regression on test data
std::vector<std::vector<Double_t> > devtrain(1); // "dev" of the regression on the training data
Expand Down Expand Up @@ -1311,9 +1316,15 @@ void TMVA::Factory::EvaluateAllMethods( void )
Log() << kINFO << "Evaluate multiclass classification method: " << theMethod->GetMethodName() << Endl;

theMethod->TestMulticlass();

// Find approximate optimal working point w.r.t. signalEfficiency * signalPurity.
multiclass_testEff.push_back(theMethod->GetMulticlassEfficiency(multiclass_testPur));

// FIXME: This code snippet is repeated in other branches
// Confusion matrix at three background efficiency levels
multiclass_testConfusionEffB01.push_back(theMethod->GetMulticlassConfusionMatrix(0.01, Types::kTesting));
multiclass_testConfusionEffB10.push_back(theMethod->GetMulticlassConfusionMatrix(0.10, Types::kTesting));
multiclass_testConfusionEffB30.push_back(theMethod->GetMulticlassConfusionMatrix(0.30, Types::kTesting));

if (not IsSilentFile()) {
Log() << kDEBUG << "\tWrite evaluation histograms to file" << Endl;
theMethod->WriteEvaluationHistosToFile(Types::kTesting);
Expand Down Expand Up @@ -1400,9 +1411,10 @@ void TMVA::Factory::EvaluateAllMethods( void )
rmstrainT[0] = vtmp[15];
minftestT[0] = vtmp[16];
minftrainT[0] = vtmp[17];
}
else if (doMulticlass) {
// TODO: fill in something meaningful
} else if (doMulticlass) {
// TODO: fill in something meaningful
// If there is some ranking of methods to be done it should be done here.
// However, this is not so easy to define for multiclass so it is left out for now.

}
else {
Expand Down Expand Up @@ -1661,33 +1673,188 @@ void TMVA::Factory::EvaluateAllMethods( void )
}
Log() << kINFO << hLine << Endl;
Log() << kINFO << Endl;
}
else if( doMulticlass ){
Log() << Endl;
TString hLine = "-------------------------------------------------------------------------------------------------------";
Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
Log() << kINFO << hLine << Endl;
// iterate over methods and evaluate
for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
MethodBase* theMethod = dynamic_cast<MethodBase*>(*itrMethod);
if(theMethod==0) continue;

TString header= "DataSet Name MVA Method ";
for(UInt_t icls = 0; icls<theMethod->fDataSetInfo.GetNClasses(); ++icls){
header += Form("%-12s ",theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
} else if (doMulticlass) {
// ====================================================================
// === Multiclass Output
// ====================================================================

TString hLine =
"-------------------------------------------------------------------------------------------------------";

// --- Achievable signal efficiency * signal purity
// --------------------------------------------------------------------
Log() << kINFO << Endl;
Log() << kINFO << "Evaluation results ranked by best signal efficiency times signal purity " << Endl;
Log() << kINFO << hLine << Endl;

// iterate over methods and evaluate
for (MVector::iterator itrMethod = methods->begin(); itrMethod != methods->end(); itrMethod++) {
MethodBase *theMethod = dynamic_cast<MethodBase *>(*itrMethod);
if (theMethod == 0) {
continue;
}

TString header = "DataSet Name MVA Method ";
for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
header += Form("%-12s ", theMethod->fDataSetInfo.GetClassInfo(icls)->GetName());
}

Log() << kINFO << header << Endl;
Log() << kINFO << hLine << Endl;
for (Int_t i = 0; i < nmeth_used[0]; i++) {
TString res = Form("[%-14s] %-15s", theMethod->fDataSetInfo.GetName(), (const char *)mname[0][i]);
for (UInt_t icls = 0; icls < theMethod->fDataSetInfo.GetNClasses(); ++icls) {
res += Form("%#1.3f ", (multiclass_testEff[i][icls]) * (multiclass_testPur[i][icls]));
}
Log() << kINFO << res << Endl;
}

Log() << kINFO << hLine << Endl;
Log() << kINFO << Endl;
}
Log() << kINFO << header << Endl;

// --- 1 vs Rest ROC AUC, signal efficiency @ given background efficiency
// --------------------------------------------------------------------
TString header1 =
Form("%-15s%-15s%-10s%-10s%-10s%-10s", "Dataset", "MVA Method", "", "Sig eff", "Sig eff", "Sig eff");
TString header2 =
Form("%-15s%-15s%-10s%-10s%-10s%-10s", "Name:", "/ Class:", "ROC AUC", "@B=0.01", "@B=0.10", "@B=0.30");
Log() << kINFO << "1-vs-rest performance metrics per class" << Endl;
Log() << kINFO << hLine << Endl;
for (Int_t i=0; i<nmeth_used[0]; i++) {
TString res = Form("[%-14s] %-15s",theMethod->fDataSetInfo.GetName(),(const char*)mname[0][i]);
for(UInt_t icls = 0; icls<theMethod->fDataSetInfo.GetNClasses(); ++icls){
res += Form("%#1.3f ",(multiclass_testEff[i][icls])*(multiclass_testPur[i][icls]));
}
Log() << kINFO << res << Endl;
Log() << kINFO << Endl;
Log() << kINFO << "Considers the listed class as signal and the other classes" << Endl;
Log() << kINFO << "as background, reporting the resulting binary performance." << Endl;

Log() << kINFO << Endl;
Log() << kINFO << header1 << Endl;
Log() << kINFO << header2 << Endl;
for (Int_t k = 0; k < 2; k++) {
for (Int_t i = 0; i < nmeth_used[k]; i++) {
if (k == 1) {
mname[k][i].ReplaceAll("Variable_", "");
}

const TString datasetName = itrMap->first;
const TString mvaName = mname[k][i];

MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, mvaName));
if (theMethod == 0) {
continue;
}

Log() << kINFO << Endl;
TString row = Form("%-15s%-15s", datasetName.Data(), mvaName.Data());
Log() << kINFO << row << Endl;

UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();
DataSet *dataset = theMethod->Data();
TMVA::Results *results = theMethod->Data()->GetResults(mname[k][i], Types::kTesting, Types::kMulticlass);

for (UInt_t iClass = 0; iClass < numClasses; ++iClass) {
std::vector<Float_t> mvaRes;
std::vector<Bool_t> mvaResType;
std::vector<Float_t> mvaResWeight;

std::vector<std::vector<Float_t>> *rawMvaRes =
dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();

// Vector transpose due to values being stored as
// [ [0, 1, 2], [0, 1, 2], ... ]
// in ResultsMulticlass::GetValueVector.
mvaRes.reserve(rawMvaRes->size());
for (auto item : *rawMvaRes) {
mvaRes.push_back(item[iClass]);
}

auto eventCollection = dataset->GetEventCollection();
mvaResType.reserve(eventCollection.size());
mvaResWeight.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResType.push_back(ev->GetClass() == iClass);
mvaResWeight.push_back(ev->GetWeight());
}

ROCCurve rocCurve = ROCCurve(mvaRes, mvaResType, mvaResWeight);

const TString className = theMethod->DataInfo().GetClassInfo(iClass)->GetName();
const Double_t rocauc = rocCurve.GetROCIntegral();
const Double_t effB01 = rocCurve.GetEffSForEffB(0.01);
const Double_t effB10 = rocCurve.GetEffSForEffB(0.10);
const Double_t effB30 = rocCurve.GetEffSForEffB(0.30);
row = Form("%-15s%-15s%-10.3f%-10.3f%-10.3f%-10.3f", "", className.Data(), rocauc, effB01, effB10,
effB30);
Log() << kINFO << row << Endl;
}
}
}
Log() << kINFO << hLine << Endl;
Log() << kINFO << Endl;
}

// --- Confusion matrices
// --------------------------------------------------------------------
auto printMatrix = [](TMatrixD mat, std::vector<TString> classnames, UInt_t numClasses, MsgLogger &stream) {
// assert (classLabeledWidth >= valueLabelWidth + 2)
// if (...) {Log() << kWARN << "..." << Endl; }

TString header = Form("%-12s", " ");
for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
header += Form(" %-12s", classnames[iCol].Data());
}
stream << kINFO << header << Endl;

for (UInt_t iRow = 0; iRow < numClasses; ++iRow) {
stream << kINFO << Form("%-12s", classnames[iRow].Data());

for (UInt_t iCol = 0; iCol < numClasses; ++iCol) {
if (iCol == iRow) {
stream << kINFO << Form(" %-12s", "-");
continue;
}

Double_t value = mat[iRow][iCol];
stream << kINFO << Form(" %-12.3f", value);
}
stream << kINFO << Endl;
}
};

Log() << kINFO << Endl;
Log() << kINFO << "Confusion matrices for all methods" << Endl;
Log() << kINFO << hLine << Endl;
Log() << kINFO << Endl;
Log() << kINFO << "Does a binary comparison between the two classes given by a " << Endl;
Log() << kINFO << "particular row-column combination. In each case, the class " << Endl;
Log() << kINFO << "given by the row is considered signal while the class given " << Endl;
Log() << kINFO << "by the column index is considered background." << Endl;
Log() << kINFO << Endl;
for (UInt_t iMethod = 0; iMethod < methods->size(); ++iMethod) {
MethodBase *theMethod = dynamic_cast<MethodBase *>(methods->at(iMethod));
if (theMethod == nullptr) {
continue;
}
UInt_t numClasses = theMethod->fDataSetInfo.GetNClasses();

std::vector<TString> classnames;
for (UInt_t iCls = 0; iCls < numClasses; ++iCls) {
classnames.push_back(theMethod->fDataSetInfo.GetClassInfo(iCls)->GetName());
}
Log() << kINFO << "Showing confusion matrix for method : " << Form("%-15s", (const char *)mname[0][iMethod])
<< Endl;
Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.01%)" << Endl;
printMatrix(multiclass_testConfusionEffB01[iMethod], classnames, numClasses, Log());
Log() << kINFO << Endl;

Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.10%)" << Endl;
printMatrix(multiclass_testConfusionEffB10[iMethod], classnames, numClasses, Log());
Log() << kINFO << Endl;

Log() << kINFO << "(Signal Efficiency for Background Efficiency 0.30%)" << Endl;
printMatrix(multiclass_testConfusionEffB30[iMethod], classnames, numClasses, Log());
Log() << kINFO << Endl;
}
Log() << kINFO << hLine << Endl;
Log() << kINFO << Endl;

} else {
// Binary classification
if (fROC) {
Expand Down
39 changes: 39 additions & 0 deletions tmva/tmva/src/MethodBase.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -2664,6 +2664,45 @@ std::vector<Float_t> TMVA::MethodBase::GetMulticlassTrainingEfficiency(std::vect
return resMulticlass->GetAchievableEff();
}

////////////////////////////////////////////////////////////////////////////////
/// Construct a confusion matrix for a multiclass classifier. The confusion
/// matrix compares, in turn, each class against all other classes in a
/// pair-wise fashion. In rows with index \f$ k_r = 0 ... K \f$, \f$ k_r \f$ is
/// considered signal for the sake of comparison and for each column
/// \f$ k_c = 0 ... K \f$ the corresponding class is considered background.
///
/// Note that the diagonal elements will be returned as NaN since this will
/// compare a class against itself.
///
/// \see TMVA::ResultsMulticlass::GetConfusionMatrix
///
/// \param[in] effB The background efficiency for which to evaluate.
/// \param[in] type The data set on which to evaluate (training, testing ...).
///
/// \return A matrix containing signal efficiencies for the given background
///         efficiency. The diagonal elements are NaN since this measure is
///         meaningless (comparing a class against itself).
///

TMatrixD TMVA::MethodBase::GetMulticlassConfusionMatrix(Double_t effB, Types::ETreeType type)
{
   if (GetAnalysisType() != Types::kMulticlass) {
      // Use the logger's Endl (not std::endl) so the message is dispatched
      // through MsgLogger; kFATAL aborts after the message is emitted.
      Log() << kFATAL << "Cannot get confusion matrix for non-multiclass analysis." << Endl;
      return TMatrixD(0, 0);
   }

   // Make sure the requested tree type (training / testing) is active before
   // fetching the corresponding results object.
   Data()->SetCurrentType(type);
   ResultsMulticlass *resMulticlass =
      dynamic_cast<ResultsMulticlass *>(Data()->GetResults(GetMethodName(), type, Types::kMulticlass));

   if (resMulticlass == nullptr) {
      Log() << kFATAL << Form("Dataset[%s] : ", DataInfo().GetName())
            << "unable to create pointer in GetMulticlassConfusionMatrix, exiting." << Endl;
      return TMatrixD(0, 0);
   }

   return resMulticlass->GetConfusionMatrix(effB);
}

////////////////////////////////////////////////////////////////////////////////
/// compute significance of mean difference
Expand Down
34 changes: 34 additions & 0 deletions tmva/tmva/src/ROCCurve.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@

*/
#include "TMVA/Tools.h"
#include "TMVA/TSpline1.h"
#include "TMVA/ROCCurve.h"
#include "TMVA/Config.h"
#include "TMVA/Version.h"
#include "TMVA/MsgLogger.h"
#include "TGraph.h"
#include "TMath.h"

#include <vector>
#include <cassert>
Expand Down Expand Up @@ -184,6 +186,38 @@ std::vector<Double_t> TMVA::ROCCurve::ComputeSensitivity(const UInt_t num_points
return sensitivity_vector;
}

////////////////////////////////////////////////////////////////////////////////
/// Calculate the signal efficiency (sensitivity) for a given background
/// efficiency (1 - specificity).
///
/// @param effB Background efficiency, in [0, 1], for which to calculate the
///             signal efficiency.
/// @param num_points Number of points used for the underlying histogram.
///                   The number of bins will be num_points - 1.
///
/// @return Interpolated signal efficiency at the requested background
///         efficiency.
///

Double_t TMVA::ROCCurve::GetEffSForEffB(Double_t effB, const UInt_t num_points)
{
   // Use && rather than the alternative token `and` for portability (MSVC
   // does not accept the iso646 spellings in its default mode).
   assert(0.0 <= effB && effB <= 1.0);

   auto effS_vec = ComputeSensitivity(num_points);
   auto effB_vec = ComputeSpecificity(num_points);

   // Specificity is actually rejB, so we need to transform it into effB.
   auto complement = [](Double_t x) { return 1 - x; };
   std::transform(effB_vec.begin(), effB_vec.end(), effB_vec.begin(), complement);

   // Since TSpline1 uses binary search (and assumes ascending sorting) we must ensure this.
   std::reverse(effS_vec.begin(), effS_vec.end());
   std::reverse(effB_vec.begin(), effB_vec.end());

   // NOTE(review): the graph is handed to TSpline1, which presumably takes
   // ownership and deletes it in its destructor — confirm against TSpline1,
   // otherwise this allocation leaks.
   TGraph *graph = new TGraph(effS_vec.size(), &effB_vec[0], &effS_vec[0]);

   // TSpline1 does linear interpolation of the ROC curve
   TSpline1 rocSpline = TSpline1("", graph);
   return rocSpline.Eval(effB);
}

////////////////////////////////////////////////////////////////////////////////
/// Calculates the ROC integral (AUC)
///
Expand Down
Loading