Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Take weights into consideration in text output.
  • Loading branch information
ashlaban committed Jun 6, 2017
commit 4d9ba3b2ee3224cf4624eb10f04af6e9f7174b95
9 changes: 7 additions & 2 deletions tmva/tmva/inc/TMVA/Factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ namespace TMVA {
class DataSetInfo;
class DataSetManager;
class DataLoader;
class ROCCurve;
class VariableTransformBase;


Expand Down Expand Up @@ -149,8 +150,8 @@ namespace TMVA {
Bool_t IsSilentFile();
Bool_t IsModelPersistence();

Double_t GetROCIntegral(DataLoader *loader,TString theMethodName);
Double_t GetROCIntegral(TString datasetname,TString theMethodName);
Double_t GetROCIntegral(DataLoader *loader,TString theMethodName, UInt_t iClass=0);
Double_t GetROCIntegral(TString datasetname,TString theMethodName, UInt_t iClass=0);

// Methods to get a TGraph for an indicated method in dataset.
// Optional title and axis added with fLegend=kTRUE.
Expand Down Expand Up @@ -179,6 +180,10 @@ namespace TMVA {
TH1F* EvaluateImportanceRandom( DataLoader *loader,UInt_t nseeds, Types::EMVA theMethod, TString methodTitle, const char *theOption = "" );

TH1F* GetImportance(const int nbits,std::vector<Double_t> importances,std::vector<TString> varNames);

// Helpers for public facing ROC methods
ROCCurve * GetROC(DataLoader *loader, TString theMethodName, UInt_t iClass=0);
ROCCurve * GetROC(TString datasetname, TString theMethodName, UInt_t iClass=0);

void WriteDataInformation(DataSetInfo& fDataSetInfo);

Expand Down
204 changes: 115 additions & 89 deletions tmva/tmva/src/Factory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -660,16 +660,94 @@ std::map<TString,Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TS

}

TMVA::ROCCurve * TMVA::Factory::GetROC(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass) {
return GetROC((TString)loader->GetName(), theMethodName, iClass);
}

///////
/// NOTE: You own the retured pointer.

TMVA::ROCCurve * TMVA::Factory::GetROC(TString datasetname, TString theMethodName, UInt_t iClass) {
if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
return nullptr;
}

if ( ! this->HasMethod(datasetname, theMethodName) ) {
Log() << kERROR << Form("Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
return nullptr;
}

std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.") << Endl;
return nullptr;
}

TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>( this->GetMethod(datasetname, theMethodName) );
TMVA::DataSet *dataset = method->Data();
TMVA::Results *results = dataset->GetResults(theMethodName, Types::kTesting, this->fAnalysisType);

UInt_t nClasses = method->DataInfo().GetNClasses();
if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
return nullptr;
}

TMVA::ROCCurve * rocCurve = nullptr;
if (this->fAnalysisType == Types::kClassification) {

std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
std::vector<Bool_t> *mvaResTypes = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
std::vector<Float_t> mvaResWeights;

auto eventCollection = dataset->GetEventCollection(Types::kTesting);
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResWeights.push_back(ev->GetWeight());
}

rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResTypes, mvaResWeights);

} else if (this->fAnalysisType == Types::kMulticlass) {
std::vector<Float_t> mvaRes;
std::vector<Bool_t> mvaResTypes;
std::vector<Float_t> mvaResWeights;

std::vector<std::vector<Float_t>> * rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();

// Vector transpose due to values being stored as
// [ [0, 1, 2], [0, 1, 2], ... ]
// in ResultsMulticlass::GetValueVector.
mvaRes.reserve(rawMvaRes->size());
for (auto item : *rawMvaRes) {
mvaRes.push_back(item[iClass]);
}

auto eventCollection = dataset->GetEventCollection(Types::kTesting);
mvaResTypes.reserve(eventCollection.size());
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResTypes.push_back(ev->GetClass() == iClass);
mvaResWeights.push_back(ev->GetWeight());
}

rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
}

return rocCurve;
}

////////////////////////////////////////////////////////////////////////////////

Double_t TMVA::Factory::GetROCIntegral(TMVA::DataLoader *loader, TString theMethodName)
Double_t TMVA::Factory::GetROCIntegral(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass)
{
return GetROCIntegral((TString)loader->GetName(),theMethodName);
return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass);
}

////////////////////////////////////////////////////////////////////////////////

Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodName)
Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass)
{
if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
Log() << kERROR << Form("DataSet = %s not found in methods map.", datasetname.Data()) << Endl;
Expand All @@ -687,18 +765,16 @@ Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodNam
return 0;
}

TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>( this->GetMethod(datasetname, theMethodName) );
TMVA::Results *results = method->Data()->GetResults(theMethodName, Types::kTesting, Types::kClassification);

std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();

TMVA::ROCCurve *fROCCurve = new TMVA::ROCCurve(*mvaRes, *mvaResType);
if (!fROCCurve) Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
if (!rocCurve) {
Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
}

Double_t fROCalcValue = fROCCurve->GetROCIntegral();
Int_t npoints = TMVA::gConfig().fVariablePlotting.fNbinsXOfROCCurve + 1;
Double_t rocIntegral = rocCurve->GetROCIntegral(npoints);
delete rocCurve;

return fROCalcValue;
return rocIntegral;
}

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -751,60 +827,10 @@ TGraph* TMVA::Factory::GetROCCurve(TString datasetname, TString theMethodName, B
Log() << kERROR << Form("Can only generate ROC curves for analysis type kClassification and kMulticlass.") << Endl;
return nullptr;
}

TMVA::MethodBase *method = dynamic_cast<TMVA::MethodBase *>( this->GetMethod(datasetname, theMethodName) );
TMVA::DataSet *dataset = method->Data();
TMVA::Results *results = dataset->GetResults(theMethodName, Types::kTesting, this->fAnalysisType);

UInt_t nClasses = method->DataInfo().GetNClasses();
if ( this->fAnalysisType == Types::kMulticlass && iClass >= nClasses ) {
Log() << kERROR << Form("Given class number (iClass = %i) does not exist. There are %i classes in dataset.", iClass, nClasses) << Endl;
return nullptr;
}

TMVA::ROCCurve *rocCurve = nullptr;
TMVA::ROCCurve *rocCurve = GetROC(datasetname, theMethodName, iClass);
TGraph *graph = nullptr;

if (this->fAnalysisType == Types::kClassification) {

std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
std::vector<Float_t> mvaResWeights;

auto eventCollection = dataset->GetEventCollection();
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResWeights.push_back(ev->GetWeight());
}

rocCurve = new TMVA::ROCCurve(*mvaRes, *mvaResType, mvaResWeights);

} else if (this->fAnalysisType == Types::kMulticlass) {
std::vector<Float_t> mvaRes;
std::vector<Bool_t> mvaResTypes;
std::vector<Float_t> mvaResWeights;

std::vector<std::vector<Float_t>> * rawMvaRes = dynamic_cast<ResultsMulticlass *>(results)->GetValueVector();

// Vector transpose due to values being stored as
// [ [0, 1, 2], [0, 1, 2], ... ]
// in ResultsMulticlass::GetValueVector.
mvaRes.reserve(rawMvaRes->size());
for (auto item : *rawMvaRes) {
mvaRes.push_back(item[iClass]);
}

auto eventCollection = dataset->GetEventCollection();
mvaResTypes.reserve(eventCollection.size());
mvaResWeights.reserve(eventCollection.size());
for (auto ev : eventCollection) {
mvaResTypes.push_back(ev->GetClass() == iClass);
mvaResWeights.push_back(ev->GetWeight());
}

rocCurve = new TMVA::ROCCurve(mvaRes, mvaResTypes, mvaResWeights);
}

if ( ! rocCurve ) {
Log() << kFATAL << Form("ROCCurve object was not created in Method = %s not found with Dataset = %s ", theMethodName.Data(), datasetname.Data()) << Endl;
return nullptr;
Expand Down Expand Up @@ -1878,51 +1904,51 @@ void TMVA::Factory::EvaluateAllMethods( void )
Log() << kINFO << "Input Variables: " << Endl << hLine << Endl;
}
for (Int_t i = 0; i < nmeth_used[k]; i++) {
if (k == 1) mname[k][i].ReplaceAll("Variable_", "");
TString datasetName = itrMap->first;
TString methodName = mname[k][i];

MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(itrMap->first, mname[k][i]));
if (theMethod == 0) continue;
TMVA::Results *results =
theMethod->Data()->GetResults(mname[k][i], Types::kTesting, Types::kClassification);
std::vector<Float_t> *mvaRes = dynamic_cast<ResultsClassification *>(results)->GetValueVector();
std::vector<Bool_t> *mvaResType =
dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();
Double_t fROCalcValue = 0;
TMVA::ROCCurve *fROCCurve = nullptr;
if (k == 1) {
methodName.ReplaceAll("Variable_", "");
}

MethodBase *theMethod = dynamic_cast<MethodBase *>(GetMethod(datasetName, methodName));
if (theMethod == 0) {
continue;
}

TMVA::DataSet *dataset = theMethod->Data();
TMVA::Results *results = dataset->GetResults(methodName, Types::kTesting, this->fAnalysisType);
std::vector<Bool_t> *mvaResType = dynamic_cast<ResultsClassification *>(results)->GetValueVectorTypes();

Double_t rocIntegral = 0.0;
if (mvaResType->size() != 0) {
fROCCurve = new TMVA::ROCCurve(*mvaRes, *mvaResType);
fROCalcValue = fROCCurve->GetROCIntegral();
rocIntegral = GetROCIntegral(datasetName, methodName);
}

if (sep[k][i] < 0 || sig[k][i] < 0) {
// cannot compute separation/significance -> no MVA (usually for Cuts)
Log() << kINFO << Form("%-13s %-15s: %#1.3f", itrMap->first.Data(), (const char *)mname[k][i],
effArea[k][i])
<< Endl;
Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), effArea[k][i]) << Endl;

// Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
// %#1.3f %#1.3f | -- --",
// itrMap->first.Data(),
// (const char*)mname[k][i],
// datasetName.Data(),
// methodName.Data(),
// eff01[k][i], Int_t(1000*eff01err[k][i]),
// eff10[k][i], Int_t(1000*eff10err[k][i]),
// eff30[k][i], Int_t(1000*eff30err[k][i]),
// effArea[k][i],fROCalcValue) << Endl;
// effArea[k][i],rocIntegral) << Endl;
} else {
Log() << kINFO
<< Form("%-13s %-15s: %#1.3f", itrMap->first.Data(), (const char *)mname[k][i], fROCalcValue)
<< Endl;
Log() << kINFO << Form("%-13s %-15s: %#1.3f", datasetName.Data(), methodName.Data(), rocIntegral) << Endl;
// Log() << kDEBUG << Form("%-20s %-15s: %#1.3f(%02i) %#1.3f(%02i) %#1.3f(%02i)
// %#1.3f %#1.3f | %#1.3f %#1.3f",
// itrMap->first.Data(),
// (const char*)mname[k][i],
// datasetName.Data(),
// methodName.Data(),
// eff01[k][i], Int_t(1000*eff01err[k][i]),
// eff10[k][i], Int_t(1000*eff10err[k][i]),
// eff30[k][i], Int_t(1000*eff30err[k][i]),
// effArea[k][i],fROCalcValue,
// effArea[k][i],rocIntegral,
// sep[k][i], sig[k][i]) << Endl;
}
if (fROCCurve) delete fROCCurve;
}
}
Log() << kINFO << hLine << Endl;
Expand Down
1 change: 1 addition & 0 deletions tmva/tmva/src/ROCCurve.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ std::vector<Double_t> TMVA::ROCCurve::ComputeSpecificity(const UInt_t num_points
specificity_vector.push_back(specificity);
}


specificity_vector.push_back(1.0);
return specificity_vector;
}
Expand Down