diff --git a/tree/treeplayer/inc/ROOT/TDFInterface.hxx b/tree/treeplayer/inc/ROOT/TDFInterface.hxx index f22c91e9c1e62..91bfef643be0e 100644 --- a/tree/treeplayer/inc/ROOT/TDFInterface.hxx +++ b/tree/treeplayer/inc/ROOT/TDFInterface.hxx @@ -1320,11 +1320,10 @@ private: void DefineDSColumnHelper(std::string_view name, TLoopManager &lm) { assert(fDataSource != nullptr); - const auto nSlots = fProxiedPtr->GetNSlots(); - auto readers = fDataSource->GetColumnReaders(name, nSlots); - auto getValuePtr = [readers](unsigned int slot) { return *readers[slot]; }; - using NewCol_t = TDFDetail::TCustomColumn; - lm.Book(std::make_shared(name, std::move(getValuePtr), ColumnNames_t{}, &lm)); + auto readers = fDataSource->GetColumnReaders(name); + auto getValue = [readers](unsigned int slot) { return **readers[slot]; }; + using NewCol_t = TDFDetail::TCustomColumn; + lm.Book(std::make_shared(name, std::move(getValue), ColumnNames_t{}, &lm)); lm.AddDataSourceColumn(name); } diff --git a/tree/treeplayer/inc/ROOT/TDFNodes.hxx b/tree/treeplayer/inc/ROOT/TDFNodes.hxx index 5b0d8a1099bac..b8ee29c81ebb8 100644 --- a/tree/treeplayer/inc/ROOT/TDFNodes.hxx +++ b/tree/treeplayer/inc/ROOT/TDFNodes.hxx @@ -369,13 +369,41 @@ public: unsigned int GetNSlots() const { return fNSlots; } }; -template +template class TCustomColumn final : public TCustomColumnBase { + /// A type that throws if default-constructed or assigned to (which is all TCustomColumn will do with it) + struct AbortIfUsed { + template + void operator=(const T&) { + assert(false && "This `Define`d column returns a non-assignable type. This is not supported."); + } + }; + using FunParamTypes_t = typename CallableTraits::arg_types; using BranchTypes_t = typename TDFInternal::RemoveFirstParameterIf::type; using TypeInd_t = TDFInternal::GenStaticSeq_t; - using ret_type = typename CallableTraits::ret_type; + // We need UpdateHelper to compile even for non-assignable or non-default-constructible return types of the custom + // column expression. + // In particular the expression `fLastResults[0] = fExpression(columns...)` would not compile if the returned type + // did not have an assignment operator, and `new ret_type()` would not compile without a default constructor. So we + // switch these problematic return types with `AbortIfUsed`, which breaks on an assertion if someone ever tries to + // actually default-construct it, or assign to it. + // This workaround is required because the code-path that reads values from a data-source always triggers compilation + // of `TCustomColumn`s of each type that a node takes in input, even though a data-source is not actually present at + // runtime, and even though that type is non-assignable/non-default-constructible. + // + // The workaround can go away if/when we support data-source columns of non-assignable types (or at least when we + // avoid doing those assignments ourselves) and if we decide to drop support of non-default-constructible types. + // + // An alternative solution might be delegating calls to `DefineDataSourceColumns` to specialized functions chosen + // at TInterface construction time (only when TDataSource is present the specialized function would actually do + // something). This way the compiler would try to compile TCustomColumns for all inputs of all nodes only when + // a TDataSource is present. + using TrueRetType_t = typename CallableTraits::ret_type; + using ret_type = typename std::conditional::value && + std::is_default_constructible::value, + TrueRetType_t, AbortIfUsed>::type; // Avoid instantiating vector as `operator[]` returns temporaries in that case. Use std::deque instead. using ValuesPerSlot_t = typename std::conditional::value, std::deque, std::vector>::type; @@ -403,11 +431,7 @@ public: fImplPtr->GetBookedColumns(), TypeInd_t()); } - void *GetValuePtr(unsigned int slot) final - { - // could be nicely made a constexpr if instead - return GetValuePtrImpl(slot, std::integral_constant()); - } + void *GetValuePtr(unsigned int slot) final { return static_cast(&fLastResults[slot]); } void Update(unsigned int slot, Long64_t entry) final { @@ -441,18 +465,6 @@ public: } void ClearValueReaders(unsigned int slot) final { ResetTDFValueTuple(fValues[slot], TypeInd_t()); } -private: - // user-defined custom columns are stored by value, take their address - void *GetValuePtrImpl(unsigned int slot, std::false_type /*IsDataSourceColumn*/) - { - return static_cast(&fLastResults[slot]); - } - - // data-source columns are stored by pointer, just cast it to void* - void *GetValuePtrImpl(unsigned int slot, std::true_type /*IsDataSourceColumn*/) - { - return static_cast(fLastResults[slot]); - } }; class TFilterBase { diff --git a/tree/treeplayer/inc/ROOT/TDataSource.hxx b/tree/treeplayer/inc/ROOT/TDataSource.hxx index 39033bf269c32..c8d3f7ef20430 100644 --- a/tree/treeplayer/inc/ROOT/TDataSource.hxx +++ b/tree/treeplayer/inc/ROOT/TDataSource.hxx @@ -30,9 +30,9 @@ public: virtual std::string GetTypeName(std::string_view) const = 0; /// Called at most once per column by TDF. Return vector of pointers to pointers to column values - one per slot. template - std::vector GetColumnReaders(std::string_view name, unsigned int nSlots) + std::vector GetColumnReaders(std::string_view name) { - auto typeErasedVec = GetColumnReadersImpl(name, nSlots, typeid(T)); + auto typeErasedVec = GetColumnReadersImpl(name, typeid(T)); std::vector typedVec(typeErasedVec.size()); std::transform(typeErasedVec.begin(), typeErasedVec.end(), typedVec.begin(), [](void *p) { return static_cast(p); }); @@ -43,6 +43,9 @@ public: virtual const std::vector> &GetEntryRanges() const = 0; /// Different threads will loop over different ranges and will pass different "slot" values. virtual void SetEntry(ULong64_t entry, unsigned int slot) = 0; + /// Method to set the number of slots. Some implementations may rely on this + /// information for optimisation purposes. + virtual void SetNSlots(unsigned int nSlots) = 0; /// Convenience method called at the start of each task, before processing a range of entries. /// DataSources can implement it if needed (does nothing by default). /// firstEntry is the first entry of the range that the task will process. @@ -51,7 +54,7 @@ public: protected: /// type-erased vector of pointers to pointers to column values - one per slot virtual std::vector - GetColumnReadersImpl(std::string_view name, unsigned int nSlots, const std::type_info &) = 0; + GetColumnReadersImpl(std::string_view name, const std::type_info &) = 0; }; } // ns TDF diff --git a/tree/treeplayer/inc/ROOT/TRootDS.hxx b/tree/treeplayer/inc/ROOT/TRootDS.hxx new file mode 100644 index 0000000000000..22d31d3530018 --- /dev/null +++ b/tree/treeplayer/inc/ROOT/TRootDS.hxx @@ -0,0 +1,42 @@ +#ifndef ROOT_TROOTTDS +#define ROOT_TROOTTDS + +#include "ROOT/TDataSource.hxx" +#include + +#include + +namespace ROOT { +namespace Experimental { +namespace TDF { + +class TRootDS final : public ROOT::Experimental::TDF::TDataSource { +private: + unsigned int fNSlots = 0U; + std::string fTreeName; + std::string fFileNameGlob; + mutable TChain fModelChain; // Mutable needed for getting the column type name + std::vector fListOfBranches; + std::vector> fEntryRanges; + std::vector> fBranchAddresses; // first container-> slot, second -> column; + std::vector> fChains; + + void InitAddresses() {} + std::vector GetColumnReadersImpl(std::string_view, const std::type_info &); + +public: + TRootDS(std::string_view treeName, std::string_view fileNameGlob); + ~TRootDS(); + std::string GetTypeName(std::string_view colName) const; + const std::vector &GetColumnNames() const; + bool HasColumn(std::string_view colName) const; + void InitSlot(unsigned int slot, ULong64_t firstEntry); + const std::vector> &GetEntryRanges() const; + void SetEntry(ULong64_t entry, unsigned int slot); + void SetNSlots(unsigned int nSlots); +}; +} +} +} + +#endif diff --git a/tree/treeplayer/inc/ROOT/TTrivialDS.hxx b/tree/treeplayer/inc/ROOT/TTrivialDS.hxx new file mode 100644 index 0000000000000..f4c0186604aca --- /dev/null +++ b/tree/treeplayer/inc/ROOT/TTrivialDS.hxx @@ -0,0 +1,33 @@ +#ifndef ROOT_TTRIVIALTDS +#define ROOT_TTRIVIALTDS + +#include "ROOT/TDataSource.hxx" + +namespace ROOT { +namespace Experimental { +namespace TDF { + +class TTrivialDS final : public ROOT::Experimental::TDF::TDataSource { +private: + unsigned int fNSlots = 0U; + ULong64_t fSize = 0ULL; + std::vector> fEntryRanges; + std::vector fColNames{"col0"}; + std::vector fCounter; + std::vector fCounterAddr; + std::vector GetColumnReadersImpl(std::string_view name, const std::type_info &); + +public: + TTrivialDS(ULong64_t size); + ~TTrivialDS(); + const std::vector &GetColumnNames() const; + bool HasColumn(std::string_view colName) const; + std::string GetTypeName(std::string_view) const; + const std::vector> &GetEntryRanges() const; + void SetEntry(ULong64_t entry, unsigned int slot); + void SetNSlots(unsigned int nSlots); +}; +} +} +} +#endif diff --git a/tree/treeplayer/src/TDFNodes.cxx b/tree/treeplayer/src/TDFNodes.cxx index 3833edb40bc2b..f07e1a9454464 100644 --- a/tree/treeplayer/src/TDFNodes.cxx +++ b/tree/treeplayer/src/TDFNodes.cxx @@ -127,6 +127,8 @@ TLoopManager::TLoopManager(std::unique_ptr ds, const ColumnNames_t : fDefaultColumns(defaultBranches), fNSlots(TDFInternal::GetNSlots()), fLoopType(ELoopType::kDataSource), fDataSource(std::move(ds)) { + const auto nSlots = TDFInternal::GetNSlots(); + fDataSource->SetNSlots( 0U == nSlots ? 1U : nSlots); } /// Run event loop with no source files, in parallel. diff --git a/tree/treeplayer/src/TRootDS.cxx b/tree/treeplayer/src/TRootDS.cxx new file mode 100644 index 0000000000000..f152ee8e92ddc --- /dev/null +++ b/tree/treeplayer/src/TRootDS.cxx @@ -0,0 +1,131 @@ +#include +#include +#include +#include + +#include +#include + +namespace ROOT { +namespace Experimental { +namespace TDF { + +std::vector TRootDS::GetColumnReadersImpl(std::string_view name, const std::type_info &) +{ + + const auto &colNames = GetColumnNames(); + + if (fBranchAddresses.empty()) { + auto nColumns = colNames.size(); + // Initialise the entire set of addresses + fBranchAddresses.resize(nColumns, std::vector(fNSlots)); + } + + const auto index = std::distance(colNames.begin(), std::find(colNames.begin(), colNames.end(), name)); + std::vector ret(fNSlots); + for (auto slot : ROOT::TSeqU(fNSlots)) { + ret[slot] = (void *)&fBranchAddresses[index][slot]; + } + return ret; +} + +TRootDS::TRootDS(std::string_view treeName, std::string_view fileNameGlob) + : fTreeName(treeName), fFileNameGlob(fileNameGlob), fModelChain(std::string(treeName).c_str()) +{ + fModelChain.Add(fFileNameGlob.c_str()); + + auto &lob = *fModelChain.GetListOfBranches(); + fListOfBranches.resize(lob.GetEntries()); + std::transform(lob.begin(), lob.end(), fListOfBranches.begin(), [](TObject *o) { return o->GetName(); }); +} + +TRootDS::~TRootDS() +{ +} + +std::string TRootDS::GetTypeName(std::string_view colName) const +{ + if (!HasColumn(colName)) { + std::string e = "The dataset does not have column "; + e += colName; + throw std::runtime_error(e); + } + // TODO: we need to factor out the routine for the branch alone... + // Maybe a cache for the names? + auto typeName = ROOT::Internal::TDF::ColumnName2ColumnTypeName(std::string(colName).c_str(), &fModelChain, + nullptr /*TCustomColumnBase here*/); + // We may not have yet loaded the library where the dictionary of this type + // is + TClass::GetClass(typeName.c_str()); + return typeName; +} + +const std::vector &TRootDS::GetColumnNames() const +{ + return fListOfBranches; +} + +bool TRootDS::HasColumn(std::string_view colName) const +{ + if (!fListOfBranches.empty()) + GetColumnNames(); + return fListOfBranches.end() != std::find(fListOfBranches.begin(), fListOfBranches.end(), colName); +} + +void TRootDS::InitSlot(unsigned int slot, ULong64_t firstEntry) +{ + auto chain = new TChain(fTreeName.c_str()); + fChains[slot].reset(chain); + chain->Add(fFileNameGlob.c_str()); + chain->GetEntry(firstEntry); + for (auto i : ROOT::TSeqU(fListOfBranches.size())) { + auto colName = fListOfBranches[i].c_str(); + auto &addr = fBranchAddresses[i][slot]; + auto typeName = GetTypeName(colName); + auto isClass = nullptr != TClass::GetClass(typeName.c_str()); + if (isClass) { + chain->SetBranchAddress(colName, &addr); + } else { + if (!addr) { + addr = new double(); // who frees this :) ? + } + chain->SetBranchAddress(colName, addr); + } + } +} + +const std::vector> &TRootDS::GetEntryRanges() const +{ + if (fEntryRanges.empty()) { + throw std::runtime_error("No ranges are available. Did you set the number of slots?"); + } + return fEntryRanges; +} + +void TRootDS::SetEntry(ULong64_t entry, unsigned int slot) +{ + fChains[slot]->GetEntry(entry); +} + +void TRootDS::SetNSlots(unsigned int nSlots) +{ + assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero."); + + fNSlots = nSlots; + fChains.resize(fNSlots); + auto nentries = fModelChain.GetEntries(); + auto chunkSize = nentries / fNSlots; + auto reminder = 1U == fNSlots ? 0 : nentries % fNSlots; + auto start = 0UL; + auto end = 0UL; + for (auto i : ROOT::TSeqU(fNSlots)) { + start = end; + end += chunkSize; + fEntryRanges.emplace_back(start, end); + (void)i; + } + fEntryRanges.back().second += reminder; +} +} +} +} diff --git a/tree/treeplayer/src/TTrivialDS.cxx b/tree/treeplayer/src/TTrivialDS.cxx new file mode 100644 index 0000000000000..63794f828f3fa --- /dev/null +++ b/tree/treeplayer/src/TTrivialDS.cxx @@ -0,0 +1,77 @@ +#include +#include +#include + +namespace ROOT { +namespace Experimental { +namespace TDF { + +std::vector TTrivialDS::GetColumnReadersImpl(std::string_view, const std::type_info &) +{ + std::vector ret; + for (auto i : ROOT::TSeqU(fNSlots)) { + fCounterAddr[i] = &fCounter[i]; + ret.emplace_back((void *)(&fCounterAddr[i])); + } + return ret; +} + +TTrivialDS::TTrivialDS(ULong64_t size) : fSize(size) +{ +} + +TTrivialDS::~TTrivialDS() +{ +} + +const std::vector &TTrivialDS::GetColumnNames() const +{ + return fColNames; +} + +bool TTrivialDS::HasColumn(std::string_view colName) const +{ + return colName == fColNames[0]; +} + +std::string TTrivialDS::GetTypeName(std::string_view) const +{ + return "ULong64_t"; +} + +const std::vector> &TTrivialDS::GetEntryRanges() const +{ + if (fEntryRanges.empty()) { + throw std::runtime_error("No ranges are available. Did you set the number of slots?"); + } + return fEntryRanges; +} + +void TTrivialDS::SetEntry(ULong64_t entry, unsigned int slot) +{ + fCounter[slot] = entry; +} + +void TTrivialDS::SetNSlots(unsigned int nSlots) +{ + assert(0U == fNSlots && "Setting the number of slots even if the number of slots is different from zero."); + + fNSlots = nSlots; + fCounter.resize(fNSlots); + fCounterAddr.resize(fNSlots); + + auto chunkSize = fSize / fNSlots; + auto start = 0UL; + auto end = 0UL; + for (auto i : ROOT::TSeqUL(fNSlots)) { + start = end; + end += chunkSize; + fEntryRanges.emplace_back(start, end); + (void)i; + } + // TODO: redistribute reminder to all slots + fEntryRanges.back().second += fSize % fNSlots; +} +} +} +} diff --git a/tree/treeplayer/test/CMakeLists.txt b/tree/treeplayer/test/CMakeLists.txt index 74616a55f9703..1716fe5094403 100644 --- a/tree/treeplayer/test/CMakeLists.txt +++ b/tree/treeplayer/test/CMakeLists.txt @@ -7,3 +7,5 @@ ROOT_ADD_GTEST(dataframe_regression dataframe/dataframe_regression.cxx LIBRARIES ROOT_ADD_GTEST(dataframe_interface dataframe/dataframe_interface.cxx LIBRARIES TreePlayer) ROOT_ADD_GTEST(dataframe_utils dataframe/dataframe_utils.cxx LIBRARIES TreePlayer) ROOT_ADD_GTEST(dataframe_nodes dataframe/dataframe_nodes.cxx LIBRARIES TreePlayer) +ROOT_ADD_GTEST(datasource_trivial dataframe/datasource_trivial.cxx LIBRARIES TreePlayer) +ROOT_ADD_GTEST(datasource_root dataframe/datasource_root.cxx LIBRARIES TreePlayer) \ No newline at end of file diff --git a/tree/treeplayer/test/dataframe/datasource_root.cxx b/tree/treeplayer/test/dataframe/datasource_root.cxx new file mode 100644 index 0000000000000..0d633196c4b2b --- /dev/null +++ b/tree/treeplayer/test/dataframe/datasource_root.cxx @@ -0,0 +1,108 @@ +#include "ROOT/TRootDS.hxx" +#include "TGraph.h" + +#include "gtest/gtest.h" + +#include + +using namespace ROOT::Experimental; +using namespace ROOT::Experimental::TDF; + +auto fileName0 = "TRootTDS_input_0.root"; +auto fileName1 = "TRootTDS_input_1.root"; +auto fileName2 = "TRootTDS_input_2.root"; +auto fileGlob = "TRootTDS_input_*.root"; +auto treeName = "t"; + +TEST(TRootDS, GenerateData) +{ + int i = 0; + TGraph g; + for (auto &&fileName : {fileName0, fileName1, fileName2}) { + TDataFrame tdf(10); + tdf.Define("i", [&i]() { return i++; }) + .Define("g", + [&g, &i]() { + g.SetPoint(i - 1, i, i); + return g; + }) + .Snapshot(treeName, fileName, {"i", "g"}); + } +} + +TEST(TRootDS, ColTypeNames) +{ + TRootDS tds(treeName, fileGlob); + tds.SetNSlots(1); + + auto colNames = tds.GetColumnNames(); + + EXPECT_TRUE(tds.HasColumn("i")); + EXPECT_TRUE(tds.HasColumn("g")); + EXPECT_FALSE(tds.HasColumn("bla")); + + EXPECT_STREQ("i", colNames[0].c_str()); + EXPECT_STREQ("g", colNames[1].c_str()); + + EXPECT_STREQ("int", tds.GetTypeName("i").c_str()); + EXPECT_STREQ("TGraph", tds.GetTypeName("g").c_str()); +} + +TEST(TRootTDS, EntryRanges) +{ + TRootDS tds(treeName, fileGlob); + tds.SetNSlots(3U); + + // Still dividing in equal parts... + auto ranges = tds.GetEntryRanges(); + + EXPECT_EQ(3U, ranges.size()); + EXPECT_EQ(0U, ranges[0].first); + EXPECT_EQ(10U, ranges[0].second); + EXPECT_EQ(10U, ranges[1].first); + EXPECT_EQ(20U, ranges[1].second); + EXPECT_EQ(20U, ranges[2].first); + EXPECT_EQ(30U, ranges[2].second); +} + +TEST(TRootTDS, ColumnReaders) +{ + TRootDS tds(treeName, fileGlob); + const auto nSlots = 3U; + tds.SetNSlots(nSlots); + auto vals = tds.GetColumnReaders("i"); + auto ranges = tds.GetEntryRanges(); + auto slot = 0U; + for (auto &&range : ranges) { + tds.InitSlot(slot, range.first); + for (auto i : ROOT::TSeq(range.first, range.second)) { + tds.SetEntry(i, slot); + auto val = **vals[slot]; + EXPECT_EQ(i, val); + } + slot++; + } +} + +TEST(TRootTDS, SetNSlotsTwice) +{ + auto theTest = []() { + TRootDS tds(treeName, fileGlob); + tds.SetNSlots(1); + tds.SetNSlots(1); + }; + ASSERT_DEATH(theTest(), "Setting the number of slots even if the number of slots is different from zero."); +} + +TEST(TRootTDS, FromATDF) +{ + std::unique_ptr tds(new TRootDS(treeName, fileGlob)); + TDataFrame tdf(std::move(tds)); + auto max = tdf.Max("i"); + auto min = tdf.Min("i"); + auto c = tdf.Count(); + + EXPECT_EQ(30U, *c); + EXPECT_DOUBLE_EQ(29., *max); + EXPECT_DOUBLE_EQ(0., *min); +} diff --git a/tree/treeplayer/test/dataframe/datasource_trivial.cxx b/tree/treeplayer/test/dataframe/datasource_trivial.cxx new file mode 100644 index 0000000000000..29bc2b553eb7f --- /dev/null +++ b/tree/treeplayer/test/dataframe/datasource_trivial.cxx @@ -0,0 +1,81 @@ +#include +#include +#include + +#include "gtest/gtest.h" + +using namespace ROOT::Experimental; +using namespace ROOT::Experimental::TDF; + +TEST(TTrivialDS, ColTypeNames) +{ + TTrivialDS tds(32); + tds.SetNSlots(1); + + auto colName = tds.GetColumnNames()[0]; // We know it's one. + EXPECT_STREQ("col0", colName.c_str()); + EXPECT_STREQ("ULong64_t", tds.GetTypeName("col0").c_str()); + + EXPECT_TRUE(tds.HasColumn("col0")); + EXPECT_FALSE(tds.HasColumn("col1")); +} + +TEST(TTrivialDS, EntryRanges) +{ + TTrivialDS tds(32); + const auto nSlots = 4U; + tds.SetNSlots(nSlots); + + auto ranges = tds.GetEntryRanges(); + + EXPECT_EQ(4U, ranges.size()); + EXPECT_EQ(0U, ranges[0].first); + EXPECT_EQ(8U, ranges[0].second); + EXPECT_EQ(8U, ranges[1].first); + EXPECT_EQ(16U, ranges[1].second); + EXPECT_EQ(16U, ranges[2].first); + EXPECT_EQ(24U, ranges[2].second); + EXPECT_EQ(24U, ranges[3].first); + EXPECT_EQ(32U, ranges[3].second); +} + +TEST(TTrivialDS, ColumnReaders) +{ + TTrivialDS tds(32); + const auto nSlots = 4U; + tds.SetNSlots(nSlots); + auto vals = tds.GetColumnReaders("col0"); + auto ranges = tds.GetEntryRanges(); + auto slot = 0U; + for (auto &&range : ranges) { + for (auto i : ROOT::TSeq(range.first, range.second)) { + tds.SetEntry(i, slot); + auto val = **vals[slot]; + EXPECT_EQ(i, val); + } + slot++; + } +} + +TEST(TTrivialDS, SetNSlotsTwice) +{ + auto theTest = []() { + TTrivialDS tds(1); + tds.SetNSlots(1); + tds.SetNSlots(1); + }; + ASSERT_DEATH(theTest(), "Setting the number of slots even if the number of slots is different from zero."); +} + +TEST(TTrivialDS, FromATDF) +{ + std::unique_ptr tds(new TTrivialDS(32)); + TDataFrame tdf(std::move(tds)); + auto max = tdf.Max("col0"); + auto min = tdf.Min("col0"); + auto c = tdf.Count(); + + EXPECT_EQ(32U, *c); + EXPECT_DOUBLE_EQ(31., *max); + EXPECT_DOUBLE_EQ(0., *min); +}