Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Update CNN and RNN tutorial to work also when ROOT is built without …
…Pymva
  • Loading branch information
lmoneta committed May 15, 2020
commit e6119aeb95a03f69acd18cdc4b210e609617c5be
19 changes: 15 additions & 4 deletions tutorials/tmva/TMVA_CNN_Classification.C
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,27 @@ void TMVA_CNN_Classification(std::vector<bool> opt = {1, 1, 1, 1})

bool writeOutputFile = true;

int num_threads = 0; // use default threads

TMVA::Tools::Instance();

// do enable MT running
ROOT::EnableImplicitMT();
if (num_threads >= 0) {
ROOT::EnableImplicitMT(num_threads);
if (num_threads > 0) gSystem->Setenv("OMP_NUM_THREADS", TString::Format("%d",num_threads));
}
else
gSystem->Setenv("OMP_NUM_THREADS", "1");

// for using Keras
std::cout << "Running with nthreads = " << ROOT::GetThreadPoolSize() << std::endl;

#ifdef R__HAS_PYMVA
gSystem->Setenv("KERAS_BACKEND", "tensorflow");
// for setting openblas in single thread on SWAN
gSystem->Setenv("OMP_NUM_THREADS", "1");
// for using Keras
TMVA::PyMethodBase::PyInitialize();
#else
useKerasCNN = false;
#endif

TFile *outputFile = nullptr;
if (writeOutputFile)
Expand Down
17 changes: 14 additions & 3 deletions tutorials/tmva/TMVA_RNN_Classification.C
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,26 @@ void TMVA_RNN_Classification(int use_type = 1)

const char *rnn_type = "RNN";

#ifdef R__HAS_PYMVA
TMVA::PyMethodBase::PyInitialize();
#else
useKeras = false;
#endif

int num_threads = 0; // use by default all threads
// do enable MT running
if (num_threads >= 0) {
ROOT::EnableImplicitMT(num_threads);
if (num_threads > 0) gSystem->Setenv("OMP_NUM_THREADS", TString::Format("%d",num_threads));
}
else
gSystem->Setenv("OMP_NUM_THREADS", "1");

ROOT::EnableImplicitMT();
TMVA::Config::Instance();

std::cout << "nthreads = " << ROOT::GetThreadPoolSize() << std::endl;
std::cout << "Running with nthreads = " << ROOT::GetThreadPoolSize() << std::endl;

TString inputFileName = "time_data_t10_d30.root";
// TString inputFileName = "/home/moneta/data/sample_images_32x32.gsoc.root";

bool fileExist = !gSystem->AccessPathName(inputFileName);

Expand Down