Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
added new method to TMVA::DataLoader
  • Loading branch information
qati committed May 23, 2016
commit b5082243e43a8533ccf973d2667b2f68e496af4b
31 changes: 16 additions & 15 deletions tmva/tmva/inc/TMVA/DataLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
#include "TMVA/DataSet.h"
#endif

#include "TH2.h"

class TFile;
class TTree;
class TDirectory;
Expand Down Expand Up @@ -88,16 +90,16 @@ namespace TMVA {
// special case: signal/background

// Data input related
void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 );
void SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut );
// Set input trees at once
void SetInputTrees( TTree* signal, TTree* background,
void SetInputTrees( TTree* signal, TTree* background,
Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;

void AddSignalTree( TTree* signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );
void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );

// ... depreciated, kept for backwards compatibility
void SetSignalTree( TTree* signal, Double_t weight=1.0);
Expand All @@ -113,9 +115,9 @@ namespace TMVA {
void SetBackgroundWeightExpression( const TString& variable );

// special case: regression
void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
Types::ETreeType treetype = Types::kMaxTreeType ) {
AddTree( tree, "Regression", weight, "", treetype );
void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
Types::ETreeType treetype = Types::kMaxTreeType ) {
AddTree( tree, "Regression", weight, "", treetype );
}

// general
Expand Down Expand Up @@ -157,10 +159,10 @@ namespace TMVA {
void PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt );
void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt );

// ... deprecated, kept for backwards compatibility
// ... deprecated, kept for backwards compatibility
void PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 );

void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
const TString& otherOpt="SplitMode=Random:!V" );

void PrepareTrainingAndTestTree( int foldNumber, Types::ETreeType tt );
Expand All @@ -169,11 +171,11 @@ namespace TMVA {
void ValidationKFoldSet();
std::vector<TTree*> SplitSets(TTree * oldTree, int seedNum, int numFolds);


TH2* GetCorrelationMatrix(const TString& className);

private:


DataInputHandler& DataInput() { return *fDataInputHandler; }
DataSetInfo& DefaultDataSetInfo();
void SetInputTreesFromEventAssignTrees();
Expand All @@ -186,7 +188,7 @@ namespace TMVA {

DataSetManager* fDataSetManager; // DSMTEST


DataInputHandler* fDataInputHandler;

std::vector<TMVA::VariableTransformBase*> fDefaultTrfs; //! list of transformations on default DataSet
Expand All @@ -199,7 +201,7 @@ namespace TMVA {
TString fName; //! name, used as directory in output

// flag determining the way training and test data are assigned to DataLoader
enum DataAssignType { kUndefined = 0,
enum DataAssignType { kUndefined = 0,
kAssignTrees,
kAssignEvents };
DataAssignType fDataAssignType; //! flags for data assigning
Expand All @@ -214,7 +216,7 @@ namespace TMVA {
Int_t fATreeType; // type of event (=classIndex)
Float_t fATreeWeight; // weight of the event
Float_t* fATreeEvent; // event variables

Types::EAnalysisType fAnalysisType; //! the training type

protected:
Expand All @@ -225,4 +227,3 @@ namespace TMVA {
} // namespace TMVA

#endif

Loading