Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add doc and explicit multiclass support for Factory:GetROCIntegral
  • Loading branch information
ashlaban committed Jun 6, 2017
commit 284f1b5cc83ae9698a57798bd6d76a14a1ffdf21
31 changes: 27 additions & 4 deletions tmva/tmva/src/Factory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,23 @@ std::map<TString,Double_t> TMVA::Factory::OptimizeAllMethods(TString fomType, TS

}

////////////////////////////////////////////////////////////////////////////////
/// Private method to generate an instance of a ROCCurve regardless of
/// analysis type.
///
/// \note You own the retured pointer.
///

TMVA::ROCCurve * TMVA::Factory::GetROC(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass) {
return GetROC((TString)loader->GetName(), theMethodName, iClass);
}

///////
/// NOTE: You own the retured pointer.
////////////////////////////////////////////////////////////////////////////////
/// Private method to generate an instance of a ROCCurve regardless of
/// analysis type.
///
/// \note You own the retured pointer.
///

TMVA::ROCCurve * TMVA::Factory::GetROC(TString datasetname, TString theMethodName, UInt_t iClass) {
if (fMethodsMap.find(datasetname) == fMethodsMap.end()) {
Expand Down Expand Up @@ -739,13 +750,25 @@ TMVA::ROCCurve * TMVA::Factory::GetROC(TString datasetname, TString theMethodNam
}

////////////////////////////////////////////////////////////////////////////////
/// Calculate the integral of the ROC curve, also known as the area under curve
/// (AUC), for a given method.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///

Double_t TMVA::Factory::GetROCIntegral(TMVA::DataLoader *loader, TString theMethodName, UInt_t iClass)
{
return GetROCIntegral((TString)loader->GetName(), theMethodName, iClass);
}

////////////////////////////////////////////////////////////////////////////////
/// Calculate the integral of the ROC curve, also known as the area under curve
/// (AUC), for a given method.
///
/// Argument iClass specifies the class to generate the ROC curve in a
/// multiclass setting. It is ignored for binary classification.
///

Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodName, UInt_t iClass)
{
Expand All @@ -759,9 +782,9 @@ Double_t TMVA::Factory::GetROCIntegral(TString datasetname, TString theMethodNam
return 0;
}

std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification/*, Types::kMulticlass*/};
std::set<Types::EAnalysisType> allowedAnalysisTypes = {Types::kClassification, Types::kMulticlass};
if ( allowedAnalysisTypes.count(this->fAnalysisType) == 0 ) {
Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification."/*"and kMulticlass."*/) << Endl;
Log() << kERROR << Form("Can only generate ROC integral for analysis type kClassification. and kMulticlass.") << Endl;
return 0;
}

Expand Down