Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
82e56f4
Stratified aggregation
guillembartrina Oct 19, 2023
ae94fe6
Add some tests and examples, fix broken tests
guillembartrina Oct 19, 2023
9c84a88
Handle anonymous variables in grouping atoms, add a few examples
guillembartrina Oct 22, 2023
9fab8d6
Simplify Grouping volcano operator, add constants generated by aggreg…
guillembartrina Oct 23, 2023
84cf40f
Remove unused AST parameter
guillembartrina Oct 26, 2023
5a746a5
Built-in constraints
guillembartrina Oct 30, 2023
21668b7
Built-in constraints
guillembartrina Oct 30, 2023
e5e71c5
Merge branch 'builtin-constraints' of github.com:guillembartrina/cara…
guillembartrina Oct 30, 2023
80f636c
Move creation of static operations, finish StagedSnippet
guillembartrina Nov 1, 2023
eb24a37
Built-in constraints
guillembartrina Oct 30, 2023
5f3088c
Move creation of static operations
guillembartrina Nov 1, 2023
2176b5a
Merge branch 'builtin-constraints' of github.com:guillembartrina/cara…
guillembartrina Nov 1, 2023
0f323c5
Optimize negation: avoid combinatorial explosion
guillembartrina Nov 5, 2023
9bb0349
Add synthetic test
guillembartrina Nov 5, 2023
e062d6e
Add bugfix comment
guillembartrina Nov 7, 2023
5a18074
Optimize negation: avoid combinatorial explosion
guillembartrina Nov 5, 2023
43451f4
Add synthetic test
guillembartrina Nov 5, 2023
0edb4a3
Merge branch 'improved-negation' of github.com:guillembartrina/carac …
guillembartrina Nov 7, 2023
dea321d
Quick fix: Prohibit negated variables from being guarded by aggregate…
guillembartrina Nov 12, 2023
61a817b
Merge branch 'main' into improved-negation
guillembartrina Nov 12, 2023
cdbd3ae
Simplify negation suboperations
guillembartrina Nov 12, 2023
9c4313f
Merge branch 'improved-negation' into builtin-constraints
guillembartrina Nov 12, 2023
5d4b52c
Remove metals files
guillembartrina Nov 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Built-in constraints
  • Loading branch information
guillembartrina committed Nov 1, 2023
commit eb24a37ca07ed2dcad9dc7fbf55a5add175404ca
73 changes: 70 additions & 3 deletions src/main/scala/datalog/dsl/DSL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def not(atom: Atom): Atom = !atom

class Atom(val rId: Int, val terms: Seq[Term], val negated: Boolean) {
def unary_! : Atom = ???
def :- (body: Atom*): Unit = ???
def :- (body: (Atom | Constraint)*): Unit = ???
def :- (body: Unit): Unit = ???
val hash: String = s"${if (negated) "!" else ""}$rId.${terms.mkString("", "", "")}"
}
Expand All @@ -40,7 +40,7 @@ case class Relation[T <: Constant](id: Int, name: String)(using ee: ExecutionEng
) extends Atom(id, terms, negated) { // extend Atom so :- can accept atom of any Relation
override def unary_! : Atom = copy(negated = !negated)
// IDB tuple
override def :-(body: Atom*): Unit =
override def :-(body: (Atom | Constraint)*): Unit =
if (negated)
throw new Exception("Cannot have negated predicates in the head of a rule")
ee.insertIDB(rId, this +: body)
Expand Down Expand Up @@ -69,7 +69,7 @@ enum AggOp(val t: Term):
case class GroupingAtom(gp: Atom, gv: Seq[Variable], ags: Seq[(AggOp, Variable)])
extends Atom(gp.rId, gv ++ ags.map(_._2), false):
// We set the relation id of the grouping predicate because the 'virtual' relation will be computed from it and also because we need it to be so for certain logic: dep in JoinIndexes, node id in DependencyGraph, etc.
override val hash: String = s"GB${gp.hash}-${gv.mkString("", "", "")}-${ags.mkString("", "", "")}"
override val hash: String = s"G#${gp.hash}-${gv.mkString("", "", "")}-${ags.mkString("", "", "")}"

object groupBy:
def apply(gp: Atom, gv: Seq[Variable], ags: (AggOp, Variable)*): GroupingAtom =
Expand All @@ -90,3 +90,70 @@ object groupBy:
if (!(aggdVars ++ gVars).subsetOf(gpVars))
throw new Exception("The aggregated variables and the grouping variables must occurr in the grouping predicate")
GroupingAtom(gp, gv, ags)


/** Comparison operators available in built-in constraints. */
enum Comparison:
  case EQ
  case NEQ
  case LT
  case LTE
  case GT
  case GTE

/** Arithmetic expression used as an operand of a comparison constraint.
  *
  * Expressions are left-nested: every binary case takes an expression on its
  * left and a single term on its right.
  */
enum Expression:
  /** A single term. */
  case One(t: Term)
  /** `l + r` */
  case Add(l: Expression, r: Term)
  /** `l - r` */
  case Sub(l: Expression, r: Term)
  /** `l * r` */
  case Mul(l: Expression, r: Term)
  /** `l / r` */
  case Div(l: Expression, r: Term)
  /** `l % r` */
  case Mod(l: Expression, r: Term)


/** A built-in comparison between two expressions.
  *
  * @param c the comparison operator
  * @param l the left-hand expression
  * @param r the right-hand expression
  */
case class Constraint(c: Comparison, l: Expression, r: Expression):
  // Textual identifier built from the operands and the operator. The original
  // interpolation ended with a stray '}' (left over from `${...}` syntax),
  // which leaked into every constraint hash; it is removed here.
  val hash: String = s"C|$l$c$r"

/** Rejects expressions that mention an anonymous variable.
  *
  * @param e the expression to validate
  * @throws Exception if any term of `e` is an anonymous variable
  */
private def checkExpression(e: Expression): Unit =
  // True only for variables flagged as anonymous; every other term is fine.
  def isAnon(term: Term): Boolean = term match
    case v: Variable => v.anon
    case _ => false
  // Walks down the left spine, testing the right-hand term at each level.
  def mentionsAnon(expr: Expression): Boolean = expr match
    case Expression.One(t) => isAnon(t)
    case Expression.Add(l, r) => isAnon(r) || mentionsAnon(l)
    case Expression.Sub(l, r) => isAnon(r) || mentionsAnon(l)
    case Expression.Mul(l, r) => isAnon(r) || mentionsAnon(l)
    case Expression.Div(l, r) => isAnon(r) || mentionsAnon(l)
    case Expression.Mod(l, r) => isAnon(r) || mentionsAnon(l)
  val invalid = mentionsAnon(e)
  if invalid then
    throw new Exception("Anonymous variable ('__') not allowed in comparison atoms")

/** Implicitly lifts a bare term into a single-term expression, so terms can be
  * used directly wherever an [[Expression]] is expected.
  */
implicit def term2ExpressionOne(x: Term): Expression.One =
  val lifted = Expression.One(x)
  lifted

// Validates both operands (left first, then right) and builds the constraint.
private def buildConstraint(cmp: Comparison, lhs: Expression, rhs: Expression): Constraint =
  checkExpression(lhs)
  checkExpression(rhs)
  Constraint(cmp, lhs, rhs)

/** Operator syntax for building expressions and comparison constraints. */
extension (e: Expression)
  // Arithmetic: each operator appends one term to the right of `e`.
  def +(t: Term): Expression.Add = Expression.Add(e, t)
  def -(t: Term): Expression.Sub = Expression.Sub(e, t)
  def *(t: Term): Expression.Mul = Expression.Mul(e, t)
  def /(t: Term): Expression.Div = Expression.Div(e, t)
  def %(t: Term): Expression.Mod = Expression.Mod(e, t)

  // Comparisons: both sides are validated before the constraint is created.
  def |=|(o: Expression): Constraint = buildConstraint(Comparison.EQ, e, o)
  def |!=|(o: Expression): Constraint = buildConstraint(Comparison.NEQ, e, o)
  def |<|(o: Expression): Constraint = buildConstraint(Comparison.LT, e, o)
  def |<=|(o: Expression): Constraint = buildConstraint(Comparison.LTE, e, o)
  def |>|(o: Expression): Constraint = buildConstraint(Comparison.GT, e, o)
  def |>=|(o: Expression): Constraint = buildConstraint(Comparison.GTE, e, o)
8 changes: 4 additions & 4 deletions src/main/scala/datalog/execution/ExecutionEngine.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package datalog.execution

import datalog.dsl.{Atom, Constant, Term, Variable}
import datalog.dsl.{Atom, Constant, Term, Variable, Constraint}
import datalog.storage.{RelationId, StorageManager}

import scala.collection.mutable
Expand All @@ -12,7 +12,7 @@ trait ExecutionEngine {
val prebuiltOpKeys: mutable.Map[RelationId, mutable.ArrayBuffer[JoinIndexes]]
def initRelation(rId: RelationId, name: String): Unit

def insertIDB(rId: RelationId, rule: Seq[Atom]): Unit
def insertIDB(rId: RelationId, rule: Seq[Atom | Constraint]): Unit
def insertEDB(body: Atom): Unit

def solve(rId: RelationId): Set[Seq[Term]]
Expand All @@ -28,8 +28,8 @@ trait ExecutionEngine {
*
* @param rule - Includes the head at idx 0
*/
inline def getOperatorKey(rule: Seq[Atom]): JoinIndexes =
JoinIndexes(rule, None, None)
inline def getOperatorKey(rule: Seq[Atom], constraints: Seq[Constraint]): JoinIndexes =
JoinIndexes(rule, constraints, None, None, None)

def getOperatorKeys(rId: RelationId): mutable.ArrayBuffer[JoinIndexes] = {
prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]())
Expand Down
199 changes: 184 additions & 15 deletions src/main/scala/datalog/execution/JoinIndexes.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package datalog.execution

import datalog.dsl.{Atom, Constant, Variable, GroupingAtom, AggOp}
import datalog.dsl.{Atom, Constant, Variable, Term, GroupingAtom, AggOp, Comparison, Expression, Constraint}
import datalog.execution.ir.{IROp, ProjectJoinFilterOp, ScanOp}
import datalog.storage.{DB, EDB, NS, RelationId, StorageManager, StorageAggOp}
import datalog.storage.{DB, EDB, NS, RelationId, StorageManager, StorageAggOp, StorageComparison, StorageExpression, getType, buildComparison}
import datalog.tools.Debug.debug

import scala.collection.mutable
Expand Down Expand Up @@ -43,6 +43,8 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
deps: Seq[(PredicateType, RelationId)],
atoms: Seq[Atom],
cxns: mutable.Map[String, mutable.Map[Int, Seq[String]]],
cons: Seq[(Option[Boolean], StorageComparison, StorageExpression, StorageExpression, Int)],
constraints: Seq[Constraint],
edb: Boolean = false,
groupingIndexes: Map[String, GroupingJoinIndexes] = Map.empty
) {
Expand All @@ -54,6 +56,7 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
", deps:" + depsToString(ns) +
", edb:" + edb +
", cxn: " + cxnsToString(ns) +
", cons: " + consToString() +
" }"

def varToString(): String = varIndexes.map(v => v.mkString("$", "==$", "")).mkString("[", ",", "]")
Expand All @@ -66,12 +69,18 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]],
inCommon.map((count, hashs) =>
count.toString + ": " + hashs.map(h => ns.hashToAtom(h)).mkString("", "|", "")
).mkString("", ", ", "")} }").mkString("[", ",\n", "]")
val hash: String = atoms.map(a => a.hash).mkString("", "", "")
def consToString(): String = cons.map((o, sc, a, b, _) => s"$o#$sc($a,$b)").mkString("{", ", ", "}")
val hash: String = atoms.map(a => a.hash).mkString("", "", "") + constraints.map(a => a.hash).mkString("", "", "")


val pos2Term: Int => Term = atoms.tail.flatMap(_.terms).apply
}

object JoinIndexes {
def apply(rule: Seq[Atom], precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]],
precalculatedGroupingIndexes: Option[Map[String, GroupingJoinIndexes]]) = {
def apply(rule: Seq[Atom], constraints: Seq[Constraint],
precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]],
consHint: Option[(Seq[(Option[Boolean], StorageComparison, StorageExpression, StorageExpression, Int)], Int => Term)],
precalculatedGroupingIndexes: Option[Map[String, GroupingJoinIndexes]]) = {
val constants = mutable.Map[Int, Constant]() // position => constant
val variables = mutable.Map[Variable, Int]() // v.oid => position

Expand Down Expand Up @@ -166,7 +175,39 @@ object JoinIndexes {
).toMap
)

new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, edb = false, groupingIndexes = groupingIndexes)

val cons = consHint.map((c, m) => c.map((o, c, l, r, _) =>
val fl = fixExpression(l, variables.apply, m)
val fr = fixExpression(r, variables.apply, m)
(o, c, fl, fr, (maxIndex(fl) ++ maxIndex(fr)).reduceOption(Math.max(_, _)).getOrElse(0)))
)
.getOrElse(
constraints.map(con =>
checkExpression(con.l, variables.keySet.toSet)
checkExpression(con.r, variables.keySet.toSet)
val sl = translateExpression(simplifyExpression(con.l), variables.apply)
val sr = translateExpression(simplifyExpression(con.r), variables.apply)
val sc = con.c match
case Comparison.EQ => StorageComparison.EQ
case Comparison.NEQ => StorageComparison.NEQ
case Comparison.LT => StorageComparison.LT
case Comparison.LTE => StorageComparison.LTE
case Comparison.GT => StorageComparison.GT
case Comparison.GTE => StorageComparison.GTE
(
(sl, sr) match
case (StorageExpression.One(c1: Constant), StorageExpression.One(c2: Constant)) => Some(buildComparison(sc, getType(c1))(c1, c2))
case _ => None
,
sc,
sl,
sr,
(maxIndex(sl) ++ maxIndex(sr)).reduceOption(Math.max(_, _)).getOrElse(0)
)
)
)

new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, cons, constraints, edb = false, groupingIndexes = groupingIndexes)
}

// used to approximate poor user-defined order
Expand Down Expand Up @@ -197,7 +238,7 @@ object JoinIndexes {
nextOpt = None

val newAtoms = originalK.atoms.head +: newBody.map(_._1)
val newHash = JoinIndexes.getRuleHash(newAtoms)
val newHash = JoinIndexes.getRuleHash(newAtoms, originalK.constraints)

// println(s"\tOrder: ${newBody.map((a, _) => s"${sm.ns(a.rId)}:|${sortBy(a)}|").mkString("", ", ", "")}")
// if (originalK.atoms.length > 3)
Expand Down Expand Up @@ -242,7 +283,7 @@ object JoinIndexes {
// println(s"\t\t\t==>next cxn to add: ${nextOpt.map(next => sm.ns.hashToAtom(next._1.hash)).getOrElse("None")}")

val newAtoms = originalK.atoms.head +: newBody.map(_._1)
val newHash = JoinIndexes.getRuleHash(newAtoms)
val newHash = JoinIndexes.getRuleHash(newAtoms, originalK.constraints)

// if (originalK.atoms.length > 3)
// print(s"Rule: ${sm.printer.ruleToString(originalK.atoms)} => ")
Expand All @@ -261,7 +302,7 @@ object JoinIndexes {
presortSelect(sortBy, originalK, sm, -1)
val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate(
newHash,
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes))
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), originalK.constraints, Some(originalK.cxns), Some((originalK.cons, originalK.pos2Term)), Some(originalK.groupingIndexes))
)
(input.map(c => ProjectJoinFilterOp(rId, newK, newBody.map((_, oldP) => c.childrenSO(oldP)): _*)), newK)
}
Expand All @@ -282,18 +323,146 @@ object JoinIndexes {
presortSelect(sortBy, originalK, sm, deltaIdx)
val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate(
newHash,
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes))
JoinIndexes(originalK.atoms.head +: newBody.map(_._1), originalK.constraints, Some(originalK.cxns), Some((originalK.cons, originalK.pos2Term)), Some(originalK.groupingIndexes))
)
(newK.atoms.drop(1).map(a => input(originalK.atoms.drop(1).indexOf(a))), newK)
}

def allOrders(rule: Seq[Atom]): AllIndexes = {
val idx = JoinIndexes(rule, None, None)
mutable.Map[String, JoinIndexes](rule.drop(1).permutations.map(r =>
val toRet = JoinIndexes(rule.head +: r, Some(idx.cxns), Some(idx.groupingIndexes))
def allOrders(rule: Seq[Atom], constraints: Seq[Constraint]): AllIndexes = {
val idx = JoinIndexes(rule, constraints, None, None, None)
mutable.Map[String, JoinIndexes](idx.atoms.drop(1).permutations.map(r =>
val toRet = JoinIndexes(rule.head +: r, idx.constraints, Some(idx.cxns), Some(idx.cons, idx.pos2Term), Some(idx.groupingIndexes))
toRet.hash -> toRet
).toSeq:_*)
}

def getRuleHash(rule: Seq[Atom]): String = rule.map(r => r.hash).mkString("", "", "")
def getRuleHash(rule: Seq[Atom], constraints: Seq[Constraint]): String = rule.map(r => r.hash).mkString("", "", "") + constraints.map(a => a.hash).mkString("", "", "")
}

// ---

/** Simplifies a constraint expression by folding adjacent constants and, where
  * it is semantically safe, moving constants towards the right so that later
  * constants can be folded into them.
  *
  * Only rewrites that preserve the value of the expression are applied:
  *  - constants are folded with the same Int / String operations used below;
  *  - a constant is moved past a variable under `+` only when it is an Int,
  *    because String `+` is concatenation and is not commutative;
  *  - multiplication and division are never reordered relative to each other,
  *    because Int division truncates: `(x / c) * v` is not `(x * v) / c`
  *    (e.g. `(1 / 2) * 4 == 0` but `(1 * 4) / 2 == 2`).
  *
  * @param e the expression to simplify
  * @return an expression equivalent to `e`
  * @throws NotImplementedError (via `???`) when two constants of unsupported
  *         or mismatched types would have to be folded
  */
private def simplifyExpression(e: Expression): Expression =
  enum ReduceOP:
    case ADD, SUB, MUL, DIV, MOD
  // Folds two constants of the same type; Strings only support ADD (concatenation).
  def reduceConstants(c1: Constant, c2: Constant, rop: ReduceOP): Constant = (c1, c2) match
    case (i1: Int, i2: Int) => rop match
      case ReduceOP.ADD => i1 + i2
      case ReduceOP.SUB => i1 - i2
      case ReduceOP.MUL => i1 * i2
      case ReduceOP.DIV => i1 / i2
      case ReduceOP.MOD => i1 % i2
    case (i1: String, i2: String) => rop match
      case ReduceOP.ADD => i1 + i2
      case _ => ???
    case _ => ???
  import Expression.*
  e match
    case One(_) => e
    case Add(l, r) => simplifyExpression(l) match
      case ne @ One(t) => (t, r) match
        case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.ADD))
        // c + v => v + c, only for Int constants (String concatenation does not commute)
        case (c1: Int, c2: Variable) => Add(One(c2), c1)
        case _ => Add(ne, r)
      case ne @ Add(l2, r2) => (r2, r) match
        // (x + c1) + c2 => x + (c1 + c2): associativity, valid for Int and String
        case (c1: Constant, c2: Constant) => Add(l2, reduceConstants(c1, c2, ReduceOP.ADD))
        // (x + c) + v => (x + v) + c, only for Int constants
        case (c1: Int, c2: Variable) => Add(Add(l2, c2), c1)
        case _ => Add(ne, r)
      case ne @ Sub(l2, r2) => (r2, r) match
        // (x - c1) + c2 => x + (c2 - c1)
        case (c1: Constant, c2: Constant) => Add(l2, reduceConstants(c2, c1, ReduceOP.SUB))
        // (x - c) + v => (x + v) - c
        case (c1: Constant, c2: Variable) => Sub(Add(l2, c2), c1)
        case _ => Add(ne, r)
      case ne => Add(ne, r)
    case Sub(l, r) => simplifyExpression(l) match
      case ne @ One(t) => (t, r) match
        case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.SUB))
        case _ => Sub(ne, r) // Case (constant, variable) could be simplified
      case ne @ Add(l2, r2) => (r2, r) match
        // (x + c1) - c2 => x + (c1 - c2)
        case (c1: Constant, c2: Constant) => Add(l2, reduceConstants(c1, c2, ReduceOP.SUB))
        // (x + c) - v => (x - v) + c
        case (c1: Constant, c2: Variable) => Add(Sub(l2, c2), c1)
        case _ => Sub(ne, r)
      case ne @ Sub(l2, r2) => (r2, r) match
        // (x - c1) - c2 => x - (c1 + c2)
        case (c1: Constant, c2: Constant) => Sub(l2, reduceConstants(c1, c2, ReduceOP.ADD))
        // (x - c) - v => (x - v) - c
        case (c1: Constant, c2: Variable) => Sub(Sub(l2, c2), c1)
        case _ => Sub(ne, r)
      case ne => Sub(ne, r)
    case Mul(l, r) => simplifyExpression(l) match
      case ne @ One(t) => (t, r) match
        case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.MUL))
        // c * v => v * c
        case (c1: Constant, c2: Variable) => Mul(One(c2), c1)
        case _ => Mul(ne, r)
      case ne @ Mul(l2, r2) => (r2, r) match
        // (x * c1) * c2 => x * (c1 * c2)
        case (c1: Constant, c2: Constant) => Mul(l2, reduceConstants(c1, c2, ReduceOP.MUL))
        // (x * c) * v => (x * v) * c
        case (c1: Constant, c2: Variable) => Mul(Mul(l2, c2), c1)
        case _ => Mul(ne, r)
      // A Div on the left is deliberately left untouched: moving the
      // multiplication inside a truncating division changes the result.
      case ne => Mul(ne, r)
    case Div(l, r) => simplifyExpression(l) match
      case ne @ One(t) => (t, r) match
        case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.DIV))
        case _ => Div(ne, r) // Case (constant, variable) could be simplified
      case ne @ Div(l2, r2) => (r2, r) match
        // (x / c1) / c2 => x / (c1 * c2): nested truncating divisions compose
        case (c1: Constant, c2: Constant) => Div(l2, reduceConstants(c1, c2, ReduceOP.MUL))
        // (x / c) / v => (x / v) / c
        case (c1: Constant, c2: Variable) => Div(Div(l2, c2), c1)
        case _ => Div(ne, r)
      // A Mul on the left is deliberately left untouched: dividing before
      // multiplying loses the remainder and changes the result.
      case ne => Div(ne, r)
    case Mod(l, r) => simplifyExpression(l) match
      case ne @ One(t) => (t, r) match
        case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.MOD))
        case _ => Mod(ne, r)
      case ne => Mod(ne, r)

/** Verifies that every variable used in `e` is one of `vars`.
  *
  * Terms are checked from left to right, so the first offending variable in
  * reading order is the one reported.
  *
  * @param e    the expression to validate
  * @param vars the variables that are allowed to appear
  * @throws Exception if a variable of `e` is not contained in `vars`
  */
private def checkExpression(e: Expression, vars: Set[Variable]): Unit =
  import Expression.*
  // Flattens the left-nested expression into its terms, leftmost first.
  def termsOf(expr: Expression, acc: List[Term]): List[Term] = expr match
    case One(t) => t :: acc
    case Add(l, r) => termsOf(l, r :: acc)
    case Sub(l, r) => termsOf(l, r :: acc)
    case Mul(l, r) => termsOf(l, r :: acc)
    case Div(l, r) => termsOf(l, r :: acc)
    case Mod(l, r) => termsOf(l, r :: acc)
  termsOf(e, Nil).foreach {
    case v: Variable if !vars.contains(v) =>
      throw new Exception(s"Variable with varId ${v.oid} appears only in comparison atoms")
    case _ => ()
  }


/** Converts a DSL expression into its storage-level counterpart.
  *
  * Constants are carried over as values (`Left`); variables are replaced by the
  * integer that `m` assigns to them (`Right`).
  *
  * @param e the expression to translate
  * @param m mapping from a variable to its integer index
  * @return the equivalent [[StorageExpression]]
  */
private def translateExpression(e: Expression, m: Variable => Int): StorageExpression =
  def operand(term: Term): Either[Constant, Int] = term match
    case c: Constant => Left(c)
    case v: Variable => Right(m(v))
  // The left sub-expression is always translated before the right-hand term.
  def go(expr: Expression): StorageExpression = expr match
    case Expression.One(t) =>
      StorageExpression.One(operand(t))
    case Expression.Add(l, r) =>
      val lhs = go(l)
      StorageExpression.Add(lhs, operand(r))
    case Expression.Sub(l, r) =>
      val lhs = go(l)
      StorageExpression.Sub(lhs, operand(r))
    case Expression.Mul(l, r) =>
      val lhs = go(l)
      StorageExpression.Mul(lhs, operand(r))
    case Expression.Div(l, r) =>
      val lhs = go(l)
      StorageExpression.Div(lhs, operand(r))
    case Expression.Mod(l, r) =>
      val lhs = go(l)
      StorageExpression.Mod(lhs, operand(r))
  go(e)

/** Rewrites the variable indexes of a storage expression.
  *
  * Each index is first resolved to a term through `rm`, the term is cast to a
  * variable, and that variable is then mapped to its new index through `m`.
  * Constant operands are left as they are.
  *
  * @param se the storage expression whose indexes are rewritten
  * @param m  mapping from a variable to its new index
  * @param rm mapping from an existing index to the term it refers to
  * @return a storage expression of the same shape with remapped indexes
  */
private def fixExpression(se: StorageExpression, m: Variable => Int, rm: Int => Term): StorageExpression =
  import StorageExpression.*
  // existing index -> term -> new index
  val remap: Int => Int = idx => m(rm(idx).asInstanceOf[Variable])
  def go(expr: StorageExpression): StorageExpression = expr match
    case One(t) => One(t.map(remap))
    case Add(l, r) =>
      val lhs = go(l)
      Add(lhs, r.map(remap))
    case Sub(l, r) =>
      val lhs = go(l)
      Sub(lhs, r.map(remap))
    case Mul(l, r) =>
      val lhs = go(l)
      Mul(lhs, r.map(remap))
    case Div(l, r) =>
      val lhs = go(l)
      Div(lhs, r.map(remap))
    case Mod(l, r) =>
      val lhs = go(l)
      Mod(lhs, r.map(remap))
  go(se)

/** Finds the largest variable index referenced by a storage expression.
  *
  * @param se the storage expression to inspect
  * @return the greatest index among the non-constant operands, or `None` when
  *         the expression contains constants only
  */
private def maxIndex(se: StorageExpression): Option[Int] =
  import StorageExpression.*
  // Gathers every index operand; constant operands contribute nothing.
  def indexes(expr: StorageExpression, acc: List[Int]): List[Int] = expr match
    case One(t) => t.toOption.toList ::: acc
    case Add(l, r) => indexes(l, r.toOption.toList ::: acc)
    case Sub(l, r) => indexes(l, r.toOption.toList ::: acc)
    case Mul(l, r) => indexes(l, r.toOption.toList ::: acc)
    case Div(l, r) => indexes(l, r.toOption.toList ::: acc)
    case Mod(l, r) => indexes(l, r.toOption.toList ::: acc)
  indexes(se, Nil).reduceOption(_ max _)

Loading