Skip to content
Prev Previous commit
[TDF] Refactor data sources in headers and implementation files
and adapt tests to the new structure.
  • Loading branch information
dpiparo committed Sep 18, 2017
commit bd2a7349c29e764e596b9f29d1a1fa387719db7c
132 changes: 13 additions & 119 deletions tree/treeplayer/inc/ROOT/TRootDS.hxx
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
#ifndef ROOT_TROOTTDS
#define ROOT_TROOTTDS

#include "ROOT/TDataFrame.hxx"
#include "ROOT/TDataSource.hxx"
#include <TChain.h>

#include "ROOT/TSeq.hxx"
#include "TChain.h"
#include "TROOT.h"

#include <algorithm>
#include <vector>
#include <memory>

namespace ROOT {
namespace Experimental {
Expand All @@ -18,128 +13,27 @@ namespace TDF {
class TRootDS final : public ROOT::Experimental::TDF::TDataSource {
private:
unsigned int fNSlots = 0U;
mutable TChain fModelChain; // Mutable needed for getting the column type name
std::string fTreeName;
std::string fFileNameGlob;
mutable TChain fModelChain; // Mutable needed for getting the column type name
std::vector<std::string> fListOfBranches;
std::vector<std::pair<ULong64_t, ULong64_t>> fEntryRanges;
std::vector<std::vector<void *>> fBranchAddresses; // first container-> slot, second -> column;

std::vector<std::unique_ptr<TChain>> fChains;

void InitAddresses() {}

std::vector<void *> GetColumnReadersImpl(std::string_view name, const std::type_info &)
{

const auto &colNames = GetColumnNames();

if (fBranchAddresses.empty()) {
auto nColumns = colNames.size();
// Initialise the entire set of addresses
fBranchAddresses.resize(nColumns, std::vector<void *>(fNSlots));
}

const auto index = std::distance(colNames.begin(), std::find(colNames.begin(), colNames.end(), name));
std::vector<void *> ret(fNSlots);
for (auto slot : ROOT::TSeqU(fNSlots)) {
ret[slot] = (void *)&fBranchAddresses[index][slot];
}
return ret;
}

// This is not even a method...
bool IsClass(const std::string &typeName) { return nullptr != TClass::GetClass(typeName.c_str()); }
std::vector<void *> GetColumnReadersImpl(std::string_view, const std::type_info &);

public:
TRootDS(std::string_view treeName, std::string_view fileNameGlob)
: fTreeName(treeName), fFileNameGlob(fileNameGlob), fModelChain(std::string(treeName).c_str())
{
fModelChain.Add(fFileNameGlob.c_str());

auto &lob = *fModelChain.GetListOfBranches();
fListOfBranches.resize(lob.GetEntries());
std::transform(lob.begin(), lob.end(), fListOfBranches.begin(), [](TObject *o) { return o->GetName(); });
}

~TRootDS(){};

std::string GetTypeName(std::string_view colName) const
{
if (!HasColumn(colName)) {
std::string e = "The dataset does not have column ";
e += colName;
throw std::runtime_error(e);
}
// TODO: we need to factor out the routine for the branch alone...
// Maybe a cache for the names?
auto typeName = ROOT::Internal::TDF::ColumnName2ColumnTypeName(std::string(colName).c_str(), &fModelChain,
nullptr /*TCustomColumnBase here*/);
// We may not have yet loaded the library where the dictionary of this type
// is
TClass::GetClass(typeName.c_str());
return typeName;
}

const std::vector<std::string> &GetColumnNames() const { return fListOfBranches; }

bool HasColumn(std::string_view colName) const
{
if (!fListOfBranches.empty())
GetColumnNames();
return fListOfBranches.end() != std::find(fListOfBranches.begin(), fListOfBranches.end(), colName);
}

void InitSlot(unsigned int slot, ULong64_t firstEntry)
{
auto chain = new TChain(fTreeName.c_str());
fChains[slot].reset(chain);
chain->Add(fFileNameGlob.c_str());
chain->GetEntry(firstEntry);
auto &theseBranchAddresses = fBranchAddresses[slot];
for (auto i : ROOT::TSeqU(fListOfBranches.size())) {
auto colName = fListOfBranches[i].c_str();
auto &addr = fBranchAddresses[i][slot];
if (IsClass(GetTypeName(colName))) {
gROOT->GetClass(GetTypeName(colName).c_str());
chain->SetBranchAddress(colName, &addr);
} else {
if (!addr) {
addr = new double(); // who frees this :) ?
}
chain->SetBranchAddress(colName, addr);
}
}
}

const std::vector<std::pair<ULong64_t, ULong64_t>> &GetEntryRanges() const
{
if (fEntryRanges.empty()) {
throw std::runtime_error("No ranges are available. Did you set the number of slots?");
}
return fEntryRanges;
}

void SetEntry(ULong64_t entry, unsigned int slot) { fChains[slot]->GetEntry(entry); }

void SetNSlots(unsigned int nSlots)
{
assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero.");

fNSlots = nSlots;
fChains.resize(fNSlots);
auto nentries = fModelChain.GetEntries();
auto chunkSize = nentries / fNSlots;
auto reminder = 1U == fNSlots ? 0 : nentries % fNSlots;
auto start = 0UL;
auto end = 0UL;
for (auto i : ROOT::TSeqU(fNSlots)) {
start = end;
end += chunkSize;
fEntryRanges.emplace_back(start, end);
}
fEntryRanges.back().second += reminder;
}
TRootDS(std::string_view treeName, std::string_view fileNameGlob);
~TRootDS();
std::string GetTypeName(std::string_view colName) const;
const std::vector<std::string> &GetColumnNames() const;
bool HasColumn(std::string_view colName) const;
void InitSlot(unsigned int slot, ULong64_t firstEntry);
const std::vector<std::pair<ULong64_t, ULong64_t>> &GetEntryRanges() const;
void SetEntry(ULong64_t entry, unsigned int slot);
void SetNSlots(unsigned int nSlots);
};
}
}
Expand Down
59 changes: 9 additions & 50 deletions tree/treeplayer/inc/ROOT/TTrivialDS.hxx
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#ifndef ROOT_TTRIVIALTDS
#define ROOT_TTRIVIALTDS

#include "ROOT/TDataFrame.hxx"
#include "ROOT/TDataSource.hxx"
#include "ROOT/TSeq.hxx"

namespace ROOT {
namespace Experimental {
Expand All @@ -17,58 +15,19 @@ private:
std::vector<std::string> fColNames{"col0"};
std::vector<ULong64_t> fCounter;
std::vector<ULong64_t *> fCounterAddr;
std::vector<void *> GetColumnReadersImpl(std::string_view name, const std::type_info &)
{
std::vector<void *> ret;
for (auto i : ROOT::TSeqU(fNSlots)) {
fCounterAddr[i] = &fCounter[i];
ret.emplace_back((void *)(&fCounterAddr[i]));
}
return ret;
}
std::vector<void *> GetColumnReadersImpl(std::string_view name, const std::type_info &);

public:
TTrivialDS(ULong64_t size) : fSize(size) {}

~TTrivialDS() {}

const std::vector<std::string> &GetColumnNames() const { return fColNames; }

bool HasColumn(std::string_view colName) const { return colName == fColNames[0]; }

std::string GetTypeName(std::string_view) const { return "ULong64_t"; }

const std::vector<std::pair<ULong64_t, ULong64_t>> &GetEntryRanges() const
{
if (fEntryRanges.empty()) {
throw std::runtime_error("No ranges are available. Did you set the number of slots?");
}
return fEntryRanges;
}
void SetEntry(ULong64_t entry, unsigned int slot) { fCounter[slot] = entry; }

void SetNSlots(unsigned int nSlots)
{
assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero.");

fNSlots = nSlots;
fCounter.resize(fNSlots);
fCounterAddr.resize(fNSlots);

auto chunkSize = fSize / fNSlots;
auto start = 0UL;
auto end = 0UL;
for (auto i : ROOT::TSeqUL(fNSlots)) {
start = end;
end += chunkSize;
fEntryRanges.emplace_back(start, end);
}
// TODO: redistribute reminder to all slots
fEntryRanges.back().second += fSize % fNSlots;
};
TTrivialDS(ULong64_t size);
~TTrivialDS();
const std::vector<std::string> &GetColumnNames() const;
bool HasColumn(std::string_view colName) const;
std::string GetTypeName(std::string_view) const;
const std::vector<std::pair<ULong64_t, ULong64_t>> &GetEntryRanges() const;
void SetEntry(ULong64_t entry, unsigned int slot);
void SetNSlots(unsigned int nSlots);
};
}
}
}

#endif
131 changes: 131 additions & 0 deletions tree/treeplayer/src/TRootDS.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#include <TClass.h>
#include <ROOT/TDFUtils.hxx>
#include <ROOT/TRootDS.hxx>
#include <ROOT/TSeq.hxx>

#include <algorithm>
#include <vector>

namespace ROOT {
namespace Experimental {
namespace TDF {

std::vector<void *> TRootDS::GetColumnReadersImpl(std::string_view name, const std::type_info &)
{

const auto &colNames = GetColumnNames();

if (fBranchAddresses.empty()) {
auto nColumns = colNames.size();
// Initialise the entire set of addresses
fBranchAddresses.resize(nColumns, std::vector<void *>(fNSlots));
}

const auto index = std::distance(colNames.begin(), std::find(colNames.begin(), colNames.end(), name));
std::vector<void *> ret(fNSlots);
for (auto slot : ROOT::TSeqU(fNSlots)) {
ret[slot] = (void *)&fBranchAddresses[index][slot];
}
return ret;
}

TRootDS::TRootDS(std::string_view treeName, std::string_view fileNameGlob)
: fTreeName(treeName), fFileNameGlob(fileNameGlob), fModelChain(std::string(treeName).c_str())
{
fModelChain.Add(fFileNameGlob.c_str());

auto &lob = *fModelChain.GetListOfBranches();
fListOfBranches.resize(lob.GetEntries());
std::transform(lob.begin(), lob.end(), fListOfBranches.begin(), [](TObject *o) { return o->GetName(); });
}

TRootDS::~TRootDS()
{
}

std::string TRootDS::GetTypeName(std::string_view colName) const
{
if (!HasColumn(colName)) {
std::string e = "The dataset does not have column ";
e += colName;
throw std::runtime_error(e);
}
// TODO: we need to factor out the routine for the branch alone...
// Maybe a cache for the names?
auto typeName = ROOT::Internal::TDF::ColumnName2ColumnTypeName(std::string(colName).c_str(), &fModelChain,
nullptr /*TCustomColumnBase here*/);
// We may not have yet loaded the library where the dictionary of this type
// is
TClass::GetClass(typeName.c_str());
return typeName;
}

const std::vector<std::string> &TRootDS::GetColumnNames() const
{
return fListOfBranches;
}

bool TRootDS::HasColumn(std::string_view colName) const
{
if (!fListOfBranches.empty())
GetColumnNames();
return fListOfBranches.end() != std::find(fListOfBranches.begin(), fListOfBranches.end(), colName);
}

void TRootDS::InitSlot(unsigned int slot, ULong64_t firstEntry)
{
auto chain = new TChain(fTreeName.c_str());
fChains[slot].reset(chain);
chain->Add(fFileNameGlob.c_str());
chain->GetEntry(firstEntry);
for (auto i : ROOT::TSeqU(fListOfBranches.size())) {
auto colName = fListOfBranches[i].c_str();
auto &addr = fBranchAddresses[i][slot];
auto typeName = GetTypeName(colName);
auto isClass = nullptr != TClass::GetClass(typeName.c_str());
if (isClass) {
chain->SetBranchAddress(colName, &addr);
} else {
if (!addr) {
addr = new double(); // who frees this :) ?
}
chain->SetBranchAddress(colName, addr);
}
}
}

const std::vector<std::pair<ULong64_t, ULong64_t>> &TRootDS::GetEntryRanges() const
{
if (fEntryRanges.empty()) {
throw std::runtime_error("No ranges are available. Did you set the number of slots?");
}
return fEntryRanges;
}

void TRootDS::SetEntry(ULong64_t entry, unsigned int slot)
{
fChains[slot]->GetEntry(entry);
}

void TRootDS::SetNSlots(unsigned int nSlots)
{
assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero.");

fNSlots = nSlots;
fChains.resize(fNSlots);
auto nentries = fModelChain.GetEntries();
auto chunkSize = nentries / fNSlots;
auto reminder = 1U == fNSlots ? 0 : nentries % fNSlots;
auto start = 0UL;
auto end = 0UL;
for (auto i : ROOT::TSeqU(fNSlots)) {
start = end;
end += chunkSize;
fEntryRanges.emplace_back(start, end);
(void)i;
}
fEntryRanges.back().second += reminder;
}
}
}
}
Loading