diff --git a/tmva/tmva/inc/TMVA/DataLoader.h b/tmva/tmva/inc/TMVA/DataLoader.h index 20cbb81926437..40dfa30f280f6 100644 --- a/tmva/tmva/inc/TMVA/DataLoader.h +++ b/tmva/tmva/inc/TMVA/DataLoader.h @@ -46,6 +46,9 @@ #include "TMVA/DataSet.h" #endif +#include "TH1F.h" +#include "TH2.h" + class TFile; class TTree; class TDirectory; @@ -88,16 +91,16 @@ namespace TMVA { // special case: signal/background // Data input related - void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName, + void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 ); void SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut ); // Set input trees at once - void SetInputTrees( TTree* signal, TTree* background, + void SetInputTrees( TTree* signal, TTree* background, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ; void AddSignalTree( TTree* signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType ); void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType ); - void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype ); + void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype ); // ... depreciated, kept for backwards compatibility void SetSignalTree( TTree* signal, Double_t weight=1.0); @@ -113,9 +116,9 @@ namespace TMVA { void SetBackgroundWeightExpression( const TString& variable ); // special case: regression - void AddRegressionTree( TTree* tree, Double_t weight = 1.0, - Types::ETreeType treetype = Types::kMaxTreeType ) { - AddTree( tree, "Regression", weight, "", treetype ); + void AddRegressionTree( TTree* tree, Double_t weight = 1.0, + Types::ETreeType treetype = Types::kMaxTreeType ) { + AddTree( tree, "Regression", weight, "", treetype ); } // general @@ -157,10 +160,10 @@ namespace TMVA { void PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt ); void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt ); - // ... deprecated, kept for backwards compatibility + // ... deprecated, kept for backwards compatibility void PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 ); - void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest, + void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest, const TString& otherOpt="SplitMode=Random:!V" ); void PrepareTrainingAndTestTree( int foldNumber, Types::ETreeType tt ); @@ -169,11 +172,12 @@ namespace TMVA { void ValidationKFoldSet(); std::vector SplitSets(TTree * oldTree, int seedNum, int numFolds); - - + TH1F* GetInputVariableHist( const TString& className, const TString& variableName, UInt_t numBin, const TString& processTrfs); + TH2* GetCorrelationMatrix( const TString& className ); + private: - + DataInputHandler& DataInput() { return *fDataInputHandler; } DataSetInfo& DefaultDataSetInfo(); void SetInputTreesFromEventAssignTrees(); @@ -186,7 +190,7 @@ namespace TMVA { DataSetManager* fDataSetManager; // DSMTEST - + DataInputHandler* fDataInputHandler; std::vector fDefaultTrfs; //! list of transformations on default DataSet @@ -199,7 +203,7 @@ namespace TMVA { TString fName; //! name, used as directory in output // flag determining the way training and test data are assigned to DataLoader - enum DataAssignType { kUndefined = 0, + enum DataAssignType { kUndefined = 0, kAssignTrees, kAssignEvents }; DataAssignType fDataAssignType; //! flags for data assigning @@ -214,7 +218,7 @@ namespace TMVA { Int_t fATreeType; // type of event (=classIndex) Float_t fATreeWeight; // weight of the event Float_t* fATreeEvent; // event variables - + Types::EAnalysisType fAnalysisType; //! the training type protected: @@ -225,4 +229,3 @@ namespace TMVA { } // namespace TMVA #endif - diff --git a/tmva/tmva/inc/TMVA/MethodBase.h b/tmva/tmva/inc/TMVA/MethodBase.h index 6b87db1538a84..e3765f8ae07d9 100644 --- a/tmva/tmva/inc/TMVA/MethodBase.h +++ b/tmva/tmva/inc/TMVA/MethodBase.h @@ -88,10 +88,12 @@ namespace TMVA { class MethodCuts; class MethodBoost; class DataSetInfo; + class DataLoader; class MethodBase : virtual public IMethod, public Configurable { friend class Factory; + friend class DataLoader; public: @@ -182,7 +184,7 @@ namespace TMVA { // helper function to set errors to -1 void NoErrorCalc(Double_t* const err, Double_t* const errUpper); - // signal/background classification response for all current set of data + // signal/background classification response for all current set of data virtual std::vector GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false); diff --git a/tmva/tmva/src/DataLoader.cxx b/tmva/tmva/src/DataLoader.cxx index ce96a7d0fe158..31e0494759eee 100644 --- a/tmva/tmva/src/DataLoader.cxx +++ b/tmva/tmva/src/DataLoader.cxx @@ -1,4 +1,4 @@ -// @(#)root/tmva $Id$ +// @(#)root/tmva $Id$ // Author: Omar Zapata // Mentors: Lorenzo Moneta, Sergei Gleyzer //NOTE: Based on TMVA::Factory @@ -127,12 +127,12 @@ TMVA::DataSetInfo& TMVA::DataLoader::AddDataSet( const TString& dsiName ) DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName); // DSMTEST if (dsi!=0) return *dsi; - + return fDataSetManager->AddDataSetInfo(*(new DataSetInfo(dsiName))); // DSMTEST } // ________________________________________________ -// the next functions are to assign events directly +// the next functions are to assign events directly //_______________________________________________________________________ TTree* TMVA::DataLoader::CreateEventAssignTrees( const TString& name ) @@ -167,42 +167,42 @@ TTree* TMVA::DataLoader::CreateEventAssignTrees( const TString& name ) } //_______________________________________________________________________ -void TMVA::DataLoader::AddSignalTrainingEvent( const std::vector& event, Double_t weight ) +void TMVA::DataLoader::AddSignalTrainingEvent( const std::vector& event, Double_t weight ) { // add signal training event AddEvent( "Signal", Types::kTraining, event, weight ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddSignalTestEvent( const std::vector& event, Double_t weight ) +void TMVA::DataLoader::AddSignalTestEvent( const std::vector& event, Double_t weight ) { // add signal testing event AddEvent( "Signal", Types::kTesting, event, weight ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddBackgroundTrainingEvent( const std::vector& event, Double_t weight ) +void TMVA::DataLoader::AddBackgroundTrainingEvent( const std::vector& event, Double_t weight ) { // add signal training event AddEvent( "Background", Types::kTraining, event, weight ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddBackgroundTestEvent( const std::vector& event, Double_t weight ) +void TMVA::DataLoader::AddBackgroundTestEvent( const std::vector& event, Double_t weight ) { // add signal training event AddEvent( "Background", Types::kTesting, event, weight ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddTrainingEvent( const TString& className, const std::vector& event, Double_t weight ) +void TMVA::DataLoader::AddTrainingEvent( const TString& className, const std::vector& event, Double_t weight ) { // add signal training event AddEvent( className, Types::kTraining, event, weight ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector& event, Double_t weight ) +void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector& event, Double_t weight ) { // add signal test event AddEvent( className, Types::kTesting, event, weight ); @@ -210,7 +210,7 @@ void TMVA::DataLoader::AddTestEvent( const TString& className, const std::vector //_______________________________________________________________________ void TMVA::DataLoader::AddEvent( const TString& className, Types::ETreeType tt, - const std::vector& event, Double_t weight ) + const std::vector& event, Double_t weight ) { // add event // vector event : the order of values is: variables + targets + spectators @@ -222,7 +222,7 @@ void TMVA::DataLoader::AddEvent( const TString& className, Types::ETreeType tt, if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 ) fAnalysisType = Types::kMulticlass; - + if (clIndex>=fTrainAssignTree.size()) { fTrainAssignTree.resize(clIndex+1, 0); fTestAssignTree.resize(clIndex+1, 0); @@ -232,7 +232,7 @@ void TMVA::DataLoader::AddEvent( const TString& className, Types::ETreeType tt, fTrainAssignTree[clIndex] = CreateEventAssignTrees( Form("TrainAssignTree_%s", className.Data()) ); fTestAssignTree[clIndex] = CreateEventAssignTrees( Form("TestAssignTree_%s", className.Data()) ); } - + fATreeType = clIndex; fATreeWeight = weight; for (UInt_t ivar=0; ivar cannot interpret tree type: \"" << treetype + Log() << kFATAL << " cannot interpret tree type: \"" << treetype << "\" should be \"Training\" or \"Test\" or \"Training and Testing\"" << Endl; } AddTree( tree, className, weight, cut, tt ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight, +void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t weight, const TCut& cut, Types::ETreeType tt ) { if(!tree) @@ -293,7 +293,7 @@ void TMVA::DataLoader::AddTree( TTree* tree, const TString& className, Double_t if( fAnalysisType == Types::kNoAnalysisType && DefaultDataSetInfo().GetNClasses() > 2 ) fAnalysisType = Types::kMulticlass; - Log() << kINFO << "Add Tree " << tree->GetName() << " of type " << className + Log() << kINFO << "Add Tree " << tree->GetName() << " of type " << className << " with " << tree->GetEntries() << " events" << Endl; DataInput().AddTree( tree, className, weight, cut, tt ); } @@ -313,10 +313,10 @@ void TMVA::DataLoader::AddSignalTree( TString datFileS, Double_t weight, Types:: // create trees from these ascii files TTree* signalTree = new TTree( "TreeS", "Tree (S)" ); signalTree->ReadFile( datFileS ); - + Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Signal file : \"" << datFileS << Endl; - + // number of signal events (used to compute significance) AddTree( signalTree, "Signal", weight, TCut(""), treetype ); } @@ -341,10 +341,10 @@ void TMVA::DataLoader::AddBackgroundTree( TString datFileB, Double_t weight, Typ // create trees from these ascii files TTree* bkgTree = new TTree( "TreeB", "Tree (B)" ); bkgTree->ReadFile( datFileB ); - + Log() << kINFO << "Create TTree objects from ASCII input files ... \n- Background file : \"" << datFileB << Endl; - + // number of signal events (used to compute significance) AddTree( bkgTree, "Background", weight, TCut(""), treetype ); } @@ -375,7 +375,7 @@ void TMVA::DataLoader::SetTree( TTree* tree, const TString& className, Double_t } //_______________________________________________________________________ -void TMVA::DataLoader::SetInputTrees( TTree* signal, TTree* background, +void TMVA::DataLoader::SetInputTrees( TTree* signal, TTree* background, Double_t signalWeight, Double_t backgroundWeight ) { // define the input trees for signal and background; no cuts are applied @@ -384,7 +384,7 @@ void TMVA::DataLoader::SetInputTrees( TTree* signal, TTree* background, } //_______________________________________________________________________ -void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& datFileB, +void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& datFileB, Double_t signalWeight, Double_t backgroundWeight ) { DataInput().AddTree( datFileS, "Signal", signalWeight ); @@ -395,18 +395,18 @@ void TMVA::DataLoader::SetInputTrees( const TString& datFileS, const TString& da void TMVA::DataLoader::SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut ) { // define the input trees for signal and background from single input tree, - // containing both signal and background events distinguished by the type + // containing both signal and background events distinguished by the type // identifiers: SigCut and BgCut AddTree( inputTree, "Signal", 1.0, SigCut, Types::kMaxTreeType ); AddTree( inputTree, "Background", 1.0, BgCut , Types::kMaxTreeType ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddVariable( const TString& expression, const TString& title, const TString& unit, +void TMVA::DataLoader::AddVariable( const TString& expression, const TString& title, const TString& unit, char type, Double_t min, Double_t max ) { // user inserts discriminating variable in data set info - DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type ); + DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type ); } //_______________________________________________________________________ @@ -414,11 +414,11 @@ void TMVA::DataLoader::AddVariable( const TString& expression, char type, Double_t min, Double_t max ) { // user inserts discriminating variable in data set info - DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type ); + DefaultDataSetInfo().AddVariable( expression, "", "", min, max, type ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddTarget( const TString& expression, const TString& title, const TString& unit, +void TMVA::DataLoader::AddTarget( const TString& expression, const TString& title, const TString& unit, Double_t min, Double_t max ) { // user inserts target in data set info @@ -426,52 +426,52 @@ void TMVA::DataLoader::AddTarget( const TString& expression, const TString& titl if( fAnalysisType == Types::kNoAnalysisType ) fAnalysisType = Types::kRegression; - DefaultDataSetInfo().AddTarget( expression, title, unit, min, max ); + DefaultDataSetInfo().AddTarget( expression, title, unit, min, max ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddSpectator( const TString& expression, const TString& title, const TString& unit, +void TMVA::DataLoader::AddSpectator( const TString& expression, const TString& title, const TString& unit, Double_t min, Double_t max ) { // user inserts target in data set info - DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max ); + DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max ); } //_______________________________________________________________________ -TMVA::DataSetInfo& TMVA::DataLoader::DefaultDataSetInfo() -{ +TMVA::DataSetInfo& TMVA::DataLoader::DefaultDataSetInfo() +{ // default creation return AddDataSet( fName ); } //_______________________________________________________________________ -void TMVA::DataLoader::SetInputVariables( std::vector* theVariables ) -{ +void TMVA::DataLoader::SetInputVariables( std::vector* theVariables ) +{ // fill input variables in data set for (std::vector::iterator it=theVariables->begin(); it!=theVariables->end(); it++) AddVariable(*it); } //_______________________________________________________________________ -void TMVA::DataLoader::SetSignalWeightExpression( const TString& variable) -{ - DefaultDataSetInfo().SetWeightExpression(variable, "Signal"); +void TMVA::DataLoader::SetSignalWeightExpression( const TString& variable) +{ + DefaultDataSetInfo().SetWeightExpression(variable, "Signal"); } //_______________________________________________________________________ -void TMVA::DataLoader::SetBackgroundWeightExpression( const TString& variable) +void TMVA::DataLoader::SetBackgroundWeightExpression( const TString& variable) { DefaultDataSetInfo().SetWeightExpression(variable, "Background"); } //_______________________________________________________________________ -void TMVA::DataLoader::SetWeightExpression( const TString& variable, const TString& className ) +void TMVA::DataLoader::SetWeightExpression( const TString& variable, const TString& className ) { //Log() << kWarning << DefaultDataSetInfo().GetNClasses() /*fClasses.size()*/ << Endl; if (className=="") { SetSignalWeightExpression(variable); SetBackgroundWeightExpression(variable); - } + } else DefaultDataSetInfo().SetWeightExpression( variable, className ); } @@ -481,25 +481,25 @@ void TMVA::DataLoader::SetCut( const TString& cut, const TString& className ) { } //_______________________________________________________________________ -void TMVA::DataLoader::SetCut( const TCut& cut, const TString& className ) +void TMVA::DataLoader::SetCut( const TCut& cut, const TString& className ) { DefaultDataSetInfo().SetCut( cut, className ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddCut( const TString& cut, const TString& className ) +void TMVA::DataLoader::AddCut( const TString& cut, const TString& className ) { AddCut( TCut(cut), className ); } //_______________________________________________________________________ -void TMVA::DataLoader::AddCut( const TCut& cut, const TString& className ) +void TMVA::DataLoader::AddCut( const TCut& cut, const TString& className ) { DefaultDataSetInfo().AddCut( cut, className ); } //_______________________________________________________________________ -void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut, +void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest, const TString& otherOpt ) { @@ -508,20 +508,20 @@ void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut, AddCut( cut ); - DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s", + DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s", NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.Data()) ); } //_______________________________________________________________________ void TMVA::DataLoader::PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest ) { - // prepare the training and test trees + // prepare the training and test trees // kept for backward compatibility SetInputTreesFromEventAssignTrees(); AddCut( cut ); - DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V", + DefaultDataSetInfo().SetSplitOptions( Form("nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V", Ntrain, Ntrain, Ntest, Ntest) ); } @@ -605,7 +605,7 @@ void TMVA::DataLoader::PrepareTrainingAndTestTree(int foldNumber, Types::ETreeTy void TMVA::DataLoader::MakeKFoldDataSet(int numberFolds) { - + UInt_t nSigTrees = DataInput().GetNSignalTrees(); UInt_t nBkgTrees = DataInput().GetNBackgroundTrees(); @@ -684,7 +684,7 @@ std::vector TMVA::DataLoader::SplitSets(TTree * oldTree, int seedNum, in TString vname = vars[ivar].GetExpression(); if(vars[ivar].GetExpression() != vars[ivar].GetLabel()){ varsSize--; - continue; + continue; } TBranch * branch = oldTree->GetBranch(vname); branches.push_back(branch); @@ -731,3 +731,77 @@ std::vector TMVA::DataLoader::SplitSets(TTree * oldTree, int seedNum, in return tempTrees; } + +//_______________________________________________________________________ +TH1F* TMVA::DataLoader::GetInputVariableHist(const TString& className, const TString& variableName, + UInt_t numBin, const TString& processTrfs="") +{ + DataSetInfo& dsinfo = DefaultDataSetInfo(); + VariableInfo* vi = nullptr; + UInt_t ivar = 0; + for(UInt_t i=0;iGetExpression() + " ("+className+")", numBin, vi->GetMin(), vi->GetMax()); + + UInt_t clsn = DefaultDataSetInfo().GetClassInfo(className)->GetNumber(); + DataSet *ds = DefaultDataSetInfo().GetDataSet(); + + + std::vector trfsDef = gTools().SplitString(processTrfs,';'); + std::vector trfs; + for (std::vector::iterator trfsDefIt = trfsDef.begin(); trfsDefIt!=trfsDef.end(); trfsDefIt++){ + trfs.push_back(new TMVA::TransformationHandler(dsinfo, "DataLoader")); + TMVA::MethodBase::CreateVariableTransforms( (*trfsDefIt), dsinfo, *(trfs.back()), Log()); + } + + const std::vector& inputEvents = dsinfo.GetDataSet()->GetEventCollection(); + const std::vector* transformed = nullptr, *tmp = nullptr; + //FIXME CalcTransformations calls PlotVariables: in my opinion here we shouldn't call that method + std::vector::iterator trfIt; + for(trfIt=trfs.begin(); trfIt != trfs.end(); trfIt++){ + if (transformed==nullptr){ + transformed = (*trfIt)->CalcTransformations(inputEvents, true); + } else { + tmp = (*trfIt)->CalcTransformations(*transformed, true); + for (auto it = transformed->begin(); it!=transformed->end(); it++) delete *it; + delete transformed; + transformed = tmp; + } + } + for(trfIt = trfs.begin(); trfIt != trfs.end(); trfIt++) delete *trfIt; + + const Event* event; + + if (transformed!=nullptr){ + for(UInt_t i=0;isize();i++){ + event = (*transformed)[i]; + if (event->GetClass() != clsn) continue; + h->Fill(event->GetValue(ivar)); + } + for (auto it = transformed->begin(); it!=transformed->end(); it++) delete *it; + delete transformed; + } else { + for(UInt_t i=0;iGetNEvents();i++){ + event = inputEvents[i]; + if (event->GetClass() != clsn) continue; + h->Fill(event->GetValue(ivar)); + } + } + return h; +} + +//_______________________________________________________________________ +TH2* TMVA::DataLoader::GetCorrelationMatrix(const TString& className) +{ + //returns the correlation matrix of datasets + const TMatrixD * m = DefaultDataSetInfo().CorrelationMatrix(className); + return DefaultDataSetInfo().CreateCorrelationMatrixHist(m, + "CorrelationMatrix"+className, "Correlation Matrix ("+className+")"); +}