Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions tmva/tmva/inc/TMVA/DataLoader.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
#include "TMVA/DataSet.h"
#endif

#include "TH1F.h"
#include "TH2.h"

class TFile;
class TTree;
class TDirectory;
Expand Down Expand Up @@ -88,16 +91,16 @@ namespace TMVA {
// special case: signal/background

// Data input related
void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 );
void SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut );
// Set input trees at once
void SetInputTrees( TTree* signal, TTree* background,
void SetInputTrees( TTree* signal, TTree* background,
Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;

void AddSignalTree( TTree* signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );
void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );

// ... depreciated, kept for backwards compatibility
void SetSignalTree( TTree* signal, Double_t weight=1.0);
Expand All @@ -113,9 +116,9 @@ namespace TMVA {
void SetBackgroundWeightExpression( const TString& variable );

// special case: regression
void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
Types::ETreeType treetype = Types::kMaxTreeType ) {
AddTree( tree, "Regression", weight, "", treetype );
void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
Types::ETreeType treetype = Types::kMaxTreeType ) {
AddTree( tree, "Regression", weight, "", treetype );
}

// general
Expand Down Expand Up @@ -157,10 +160,10 @@ namespace TMVA {
void PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt );
void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt );

// ... deprecated, kept for backwards compatibility
// ... deprecated, kept for backwards compatibility
void PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 );

void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
const TString& otherOpt="SplitMode=Random:!V" );

void PrepareTrainingAndTestTree( int foldNumber, Types::ETreeType tt );
Expand All @@ -169,11 +172,12 @@ namespace TMVA {
void ValidationKFoldSet();
std::vector<TTree*> SplitSets(TTree * oldTree, int seedNum, int numFolds);



TH1F* GetInputVariableHist( const TString& className, const TString& variableName, UInt_t numBin, const TString& processTrfs);
TH2* GetCorrelationMatrix( const TString& className );

private:


DataInputHandler& DataInput() { return *fDataInputHandler; }
DataSetInfo& DefaultDataSetInfo();
void SetInputTreesFromEventAssignTrees();
Expand All @@ -186,7 +190,7 @@ namespace TMVA {

DataSetManager* fDataSetManager; // DSMTEST


DataInputHandler* fDataInputHandler;

std::vector<TMVA::VariableTransformBase*> fDefaultTrfs; //! list of transformations on default DataSet
Expand All @@ -199,7 +203,7 @@ namespace TMVA {
TString fName; //! name, used as directory in output

// flag determining the way training and test data are assigned to DataLoader
enum DataAssignType { kUndefined = 0,
enum DataAssignType { kUndefined = 0,
kAssignTrees,
kAssignEvents };
DataAssignType fDataAssignType; //! flags for data assigning
Expand All @@ -214,7 +218,7 @@ namespace TMVA {
Int_t fATreeType; // type of event (=classIndex)
Float_t fATreeWeight; // weight of the event
Float_t* fATreeEvent; // event variables

Types::EAnalysisType fAnalysisType; //! the training type

protected:
Expand All @@ -225,4 +229,3 @@ namespace TMVA {
} // namespace TMVA

#endif

4 changes: 3 additions & 1 deletion tmva/tmva/inc/TMVA/MethodBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,12 @@ namespace TMVA {
class MethodCuts;
class MethodBoost;
class DataSetInfo;
class DataLoader;

class MethodBase : virtual public IMethod, public Configurable {

friend class Factory;
friend class DataLoader;

public:

Expand Down Expand Up @@ -182,7 +184,7 @@ namespace TMVA {
// helper function to set errors to -1
void NoErrorCalc(Double_t* const err, Double_t* const errUpper);

// signal/background classification response for all current set of data
// signal/background classification response for all current set of data
virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);


Expand Down
Loading