diff --git a/src/main/scala/datalog/dsl/DSL.scala b/src/main/scala/datalog/dsl/DSL.scala index 56ab9ae5..6f4469ea 100644 --- a/src/main/scala/datalog/dsl/DSL.scala +++ b/src/main/scala/datalog/dsl/DSL.scala @@ -26,7 +26,7 @@ def not(atom: Atom): Atom = !atom class Atom(val rId: Int, val terms: Seq[Term], val negated: Boolean) { def unary_! : Atom = ??? - def :- (body: Atom*): Unit = ??? + def :- (body: (Atom | Constraint)*): Unit = ??? def :- (body: Unit): Unit = ??? val hash: String = s"${if (negated) "!" else ""}$rId.${terms.mkString("", "", "")}" } @@ -40,7 +40,7 @@ case class Relation[T <: Constant](id: Int, name: String)(using ee: ExecutionEng ) extends Atom(id, terms, negated) { // extend Atom so :- can accept atom of any Relation override def unary_! : Atom = copy(negated = !negated) // IDB tuple - override def :-(body: Atom*): Unit = + override def :-(body: (Atom | Constraint)*): Unit = if (negated) throw new Exception("Cannot have negated predicates in the head of a rule") ee.insertIDB(rId, this +: body) @@ -69,7 +69,7 @@ enum AggOp(val t: Term): case class GroupingAtom(gp: Atom, gv: Seq[Variable], ags: Seq[(AggOp, Variable)]) extends Atom(gp.rId, gv ++ ags.map(_._2), false): // We set the relation id of the grouping predicate because the 'virtual' relation will be computed from it and also because we need it to be so for certain logic: dep in JoinIndexes, node id in DependencyGraph, etc. 
- override val hash: String = s"GB${gp.hash}-${gv.mkString("", "", "")}-${ags.mkString("", "", "")}" + override val hash: String = s"G#${gp.hash}-${gv.mkString("", "", "")}-${ags.mkString("", "", "")}" object groupBy: def apply(gp: Atom, gv: Seq[Variable], ags: (AggOp, Variable)*): GroupingAtom = @@ -90,3 +90,70 @@ object groupBy: if (!(aggdVars ++ gVars).subsetOf(gpVars)) throw new Exception("The aggregated variables and the grouping variables must occurr in the grouping predicate") GroupingAtom(gp, gv, ags) + + +enum Comparison: + case EQ, NEQ, LT, LTE, GT, GTE + +enum Expression: + case One(t: Term) + case Add(l: Expression, r: Term) + case Sub(l: Expression, r: Term) + case Mul(l: Expression, r: Term) + case Div(l: Expression, r: Term) + case Mod(l: Expression, r: Term) + + +case class Constraint(c: Comparison, l: Expression, r: Expression): + val hash: String = s"C|$l$c$r}" + +private def checkExpression(e: Expression): Unit = + inline def isAnonVariable(t: Term): Boolean = t.isInstanceOf[Variable] && t.asInstanceOf[Variable].anon + def aux(e: Expression): Boolean = e match + case Expression.One(t) => isAnonVariable(t) + case Expression.Add(l, r) => aux(l) || isAnonVariable(r) + case Expression.Sub(l, r) => aux(l) || isAnonVariable(r) + case Expression.Mul(l, r) => aux(l) || isAnonVariable(r) + case Expression.Div(l, r) => aux(l) || isAnonVariable(r) + case Expression.Mod(l, r) => aux(l) || isAnonVariable(r) + if (aux(e)) + throw new Exception("Anonymous variable ('__') not allowed in comparison atoms") + +implicit def term2ExpressionOne(x: Term): Expression.One = Expression.One(x) + +extension (e: Expression) + def +(t: Term): Expression.Add = + Expression.Add(e, t) + def -(t: Term): Expression.Sub = + Expression.Sub(e, t) + def *(t: Term): Expression.Mul = + Expression.Mul(e, t) + def /(t: Term): Expression.Div = + Expression.Div(e, t) + def %(t: Term): Expression.Mod = + Expression.Mod(e, t) + + def |=|(o: Expression): Constraint = + checkExpression(e) + 
checkExpression(o) + Constraint(Comparison.EQ, e, o) + def |!=|(o: Expression): Constraint = + checkExpression(e) + checkExpression(o) + Constraint(Comparison.NEQ, e, o) + def |<|(o: Expression): Constraint = + checkExpression(e) + checkExpression(o) + Constraint(Comparison.LT, e, o) + def |<=|(o: Expression): Constraint = + checkExpression(e) + checkExpression(o) + Constraint(Comparison.LTE, e, o) + def |>|(o: Expression): Constraint = + checkExpression(e) + checkExpression(o) + Constraint(Comparison.GT, e, o) + def |>=|(o: Expression): Constraint = + checkExpression(e) + checkExpression(o) + Constraint(Comparison.GTE, e, o) diff --git a/src/main/scala/datalog/execution/BytecodeCompiler.scala b/src/main/scala/datalog/execution/BytecodeCompiler.scala index ca7531c5..20c062bc 100644 --- a/src/main/scala/datalog/execution/BytecodeCompiler.scala +++ b/src/main/scala/datalog/execution/BytecodeCompiler.scala @@ -102,10 +102,17 @@ class BytecodeCompiler(val storageManager: StorageManager)(using JITOptions) ext .constantInstruction(rId) emitSMCall(xb, meth, classOf[Int]) - case ComplementOp(arity) => + case NegationOp(child, cols) => + val tmp = cols.map(_.exists(_.isEmpty)) xb.aload(0) - .constantInstruction(arity) - emitSMCall(xb, "getComplement", classOf[Int]) + xb.aload(0) + emitCols(xb, cols) + emitSMCall(xb, "getGroundOf", classOf[Seq[?]]) + xb.aload(0) + traverse(xb, child) + emitSeq(xb, tmp.map(v => xxb => emitBoolean(xxb, v))) + emitSMCall(xb, "zeroOut", classOf[EDB], classOf[Seq[?]]) + emitSMCall(xb, "diff", classOf[EDB], classOf[EDB]) case ScanEDBOp(rId) => xb.aload(0) diff --git a/src/main/scala/datalog/execution/BytecodeGenerator.scala b/src/main/scala/datalog/execution/BytecodeGenerator.scala index 2cf41537..549c5419 100644 --- a/src/main/scala/datalog/execution/BytecodeGenerator.scala +++ b/src/main/scala/datalog/execution/BytecodeGenerator.scala @@ -179,6 +179,12 @@ object BytecodeGenerator { else xb.constantInstruction(0) + /** Emit 
`Boolean.valueOf($value)`. */ + def emitBoolean(xb: CodeBuilder, value: Boolean): Unit = + xb.constantInstruction(if value then 1 else 0) + .invokestatic(clsDesc(classOf[java.lang.Boolean]), "valueOf", + MethodTypeDesc.of(clsDesc(classOf[java.lang.Boolean]), clsDesc(classOf[Boolean]))) + def emitSeqInt(xb: CodeBuilder, value: Seq[Int]): Unit = emitSeq(xb, value.map(v => xxb => emitInteger(xxb, v))) @@ -248,6 +254,17 @@ object BytecodeGenerator { } } + def emitEither[A, B](xb: CodeBuilder, either: Either[A, B], emitA: (CodeBuilder, A) => Unit, emitB: (CodeBuilder, B) => Unit): Unit = + either match + case Left(value) => + emitNew(xb, classOf[Left[A, B]], { xxb => + emitA(xxb, value) + }) + case Right(value) => + emitNew(xb, classOf[Right[A, B]], { xxb => + emitB(xxb, value) + }) + def emitProjIndexes(xb: CodeBuilder, value: Seq[(String, Constant)]): Unit = emitSeq(xb, value.map(v => xxb => emitStringConstantTuple2(xxb, v))) @@ -268,6 +285,7 @@ object BytecodeGenerator { def emitCxns(xb: CodeBuilder, value: collection.mutable.Map[String, collection.mutable.Map[Int, Seq[String]]]): Unit = emitMap(xb, value.toSeq, emitString, emitCxnElement) + /* def emitJoinIndexes(xb: CodeBuilder, value: JoinIndexes): Unit = emitNew(xb, classOf[JoinIndexes], xxb => emitVarIndexes(xxb, value.varIndexes) @@ -277,7 +295,11 @@ object BytecodeGenerator { // emitArrayAtoms(xxb, value.atoms) emitSeq(xb, value.atoms.map(a => xxb => emitAtom(xxb, a))) emitCxns(xxb, value.cxns) - emitBool(xxb, value.edb)) + // TODO: Missing negationInfo! + emitBool(xxb, value.edb), + // TODO: Missing groupingInfos! 
+ ) + */ def emitStorageAggOp(xb: CodeBuilder, sao: StorageAggOp): Unit = val enumCompanionCls = classOf[StorageAggOp.type] @@ -315,6 +337,18 @@ object BytecodeGenerator { emitSeqInt(xxb, value.groupingIndexes) emitAggOpInfos(xxb, value.aggOpInfos)) + def emitCols(xb: CodeBuilder, value: Seq[Either[Constant, Seq[(RelationId, Int)]]]): Unit = + emitSeq(xb, value.map(v => xxb => + emitEither(xxb, v, emitConstant, (xxxb, s) => + emitSeq(xxxb, s.map(vv => xxxxb => + emitNew(xxxxb, classOf[(Int, Int)], xxxxxb => + emitInteger(xxxxxb, vv._1) + emitInteger(xxxxxb, vv._2) + ) + )) + ) + )) + val CD_BoxedUnit = clsDesc(classOf[scala.runtime.BoxedUnit]) /** Emit `BoxedUnit.UNIT`. */ diff --git a/src/main/scala/datalog/execution/ExecutionEngine.scala b/src/main/scala/datalog/execution/ExecutionEngine.scala index 1d384099..5a01d314 100644 --- a/src/main/scala/datalog/execution/ExecutionEngine.scala +++ b/src/main/scala/datalog/execution/ExecutionEngine.scala @@ -1,6 +1,6 @@ package datalog.execution -import datalog.dsl.{Atom, Constant, Term, Variable} +import datalog.dsl.{Atom, Constant, Term, Variable, Constraint} import datalog.storage.{RelationId, StorageManager} import scala.collection.mutable @@ -12,7 +12,7 @@ trait ExecutionEngine { val prebuiltOpKeys: mutable.Map[RelationId, mutable.ArrayBuffer[JoinIndexes]] def initRelation(rId: RelationId, name: String): Unit - def insertIDB(rId: RelationId, rule: Seq[Atom]): Unit + def insertIDB(rId: RelationId, rule: Seq[Atom | Constraint]): Unit def insertEDB(body: Atom): Unit def solve(rId: RelationId): Set[Seq[Term]] @@ -28,8 +28,8 @@ trait ExecutionEngine { * * @param rule - Includes the head at idx 0 */ - inline def getOperatorKey(rule: Seq[Atom]): JoinIndexes = - JoinIndexes(rule, None, None) + inline def getOperatorKey(rule: Seq[Atom], constraints: Seq[Constraint]): JoinIndexes = + JoinIndexes(rule, constraints, None, None, None) def getOperatorKeys(rId: RelationId): mutable.ArrayBuffer[JoinIndexes] = { 
prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()) diff --git a/src/main/scala/datalog/execution/JoinIndexes.scala b/src/main/scala/datalog/execution/JoinIndexes.scala index ab5bcb47..15211f17 100644 --- a/src/main/scala/datalog/execution/JoinIndexes.scala +++ b/src/main/scala/datalog/execution/JoinIndexes.scala @@ -1,8 +1,8 @@ package datalog.execution -import datalog.dsl.{Atom, Constant, Variable, GroupingAtom, AggOp} +import datalog.dsl.{Atom, Constant, Variable, Term, GroupingAtom, AggOp, Comparison, Expression, Constraint} import datalog.execution.ir.{IROp, ProjectJoinFilterOp, ScanOp} -import datalog.storage.{DB, EDB, NS, RelationId, StorageManager, StorageAggOp} +import datalog.storage.{DB, EDB, NS, RelationId, StorageManager, StorageAggOp, StorageComparison, StorageExpression, getType, comparisons} import datalog.tools.Debug.debug import scala.collection.mutable @@ -36,6 +36,7 @@ case class GroupingJoinIndexes(varIndexes: Seq[Seq[Int]], * @param edb - for rules that have EDBs defined on the same predicate, just read * @param atoms - the original atoms from the DSL * @param cxns - convenience data structure tracking how many variables in common each atom has with every other atom. 
+ * @param negationInfo - information needed to build the complement relation of negated atoms: for each term, either a constant or a list of pairs (relationId, column) of the occurrences of the variable in the rule (empty for anonymous variable) */ case class JoinIndexes(varIndexes: Seq[Seq[Int]], constIndexes: mutable.Map[Int, Constant], @@ -43,6 +44,9 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]], deps: Seq[(PredicateType, RelationId)], atoms: Seq[Atom], cxns: mutable.Map[String, mutable.Map[Int, Seq[String]]], + cons: Seq[(Option[Boolean], StorageComparison, StorageExpression, StorageExpression, Int)], + constraints: Seq[Constraint], + negationInfo: Map[String, Seq[Either[Constant, Seq[(RelationId, Int)]]]], edb: Boolean = false, groupingIndexes: Map[String, GroupingJoinIndexes] = Map.empty ) { @@ -54,6 +58,8 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]], ", deps:" + depsToString(ns) + ", edb:" + edb + ", cxn: " + cxnsToString(ns) + + ", cons: " + consToString() + + ", negation: " + negationToString(ns) + " }" def varToString(): String = varIndexes.map(v => v.mkString("$", "==$", "")).mkString("[", ",", "]") @@ -66,12 +72,24 @@ case class JoinIndexes(varIndexes: Seq[Seq[Int]], inCommon.map((count, hashs) => count.toString + ": " + hashs.map(h => ns.hashToAtom(h)).mkString("", "|", "") ).mkString("", ", ", "")} }").mkString("[", ",\n", "]") - val hash: String = atoms.map(a => a.hash).mkString("", "", "") + def consToString(): String = cons.map((o, sc, a, b, _) => s"$o#$sc($a,$b)").mkString("{", ", ", "}") + def negationToString(ns: NS): String = + negationInfo.map((h, infos) => + s"{ ${ns.hashToAtom(h)} => ${ + infos.map{ + case Left(value) => value + case Right(value) => s"[ ${value.map((r, c) => s"(${ns(r)}, $c)")} ]" + }} }").mkString("[", ",\n", "]") + val hash: String = atoms.map(a => a.hash).mkString("", "", "") + constraints.map(a => a.hash).mkString("", "", "") + + val pos2Term: Int => Term = atoms.tail.flatMap(_.terms).apply } object
JoinIndexes { - def apply(rule: Seq[Atom], precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]], - precalculatedGroupingIndexes: Option[Map[String, GroupingJoinIndexes]]) = { + def apply(rule: Seq[Atom], constraints: Seq[Constraint], + precalculatedCxns: Option[mutable.Map[String, mutable.Map[Int, Seq[String]]]], + consHint: Option[(Seq[(Option[Boolean], StorageComparison, StorageExpression, StorageExpression, Int)], Int => Term)], + precalculatedGroupingIndexes: Option[Map[String, GroupingJoinIndexes]]) = { val constants = mutable.Map[Int, Constant]() // position => constant val variables = mutable.Map[Variable, Int]() // v.oid => position @@ -83,23 +101,19 @@ object JoinIndexes { case _ => if (a.negated) PredicateType.NEGATED else PredicateType.POSITIVE , a.rId)) - val typeHelper = body.flatMap(a => a.terms.map(* => !a.negated)) - val bodyVars = body - .flatMap(a => a.terms) // all terms in one seq + .flatMap(a => a.terms.zipWithIndex.map((t, i) => (t, (a.negated, a.isInstanceOf[GroupingAtom] && i >= a.asInstanceOf[GroupingAtom].gv.length)))) // all terms in one seq .zipWithIndex // term, position - .groupBy(z => z._1) // group by term + .groupBy(z => z._1._1) // group by term .filter((term, matches) => // matches = Seq[(var, pos1), (var, pos2), ...] 
term match { case v: Variable => - matches.map(_._2).find(typeHelper) match - case Some(pos) => - variables(v) = pos - case None => - if (v.oid != -1) - throw new Exception(s"Variable with varId ${v.oid} appears only in negated rules") - else - () + val wrong = v.oid != -1 && matches.exists(_._1._2._1) && matches.forall(x => x._1._2._1 || x._1._2._2) // Var occurs negated and all occurrences are either negated or aggregated + if wrong then + throw new Exception(s"Variable with varId ${v.oid} appears only in negated atoms (and possibly in aggregated positions of grouping atoms)") + else + if (v.oid != -1) + variables(v) = matches.find(!_._1._2._1).get._2 !v.anon && matches.length >= 2 case c: Constant => matches.foreach((_, idx) => constants(idx) = c) @@ -137,6 +151,18 @@ object JoinIndexes { )).to(mutable.Map) ) + + val variables2 = body.filterNot(_.negated).flatMap(a => + a.terms.zipWithIndex.collect{ case (v: Variable, i) if !v.anon => (v, i) }.map((v, i) => (v, (a.rId, i))) + ).groupBy(_._1).view.mapValues(_.map(_._2)) + + val negationInfo = body.filter(_.negated).map(a => + a.hash -> a.terms.map{ + case c: Constant => Left(c) + case v: Variable => Right(if v.anon then Seq() else variables2(v)) + } + ).toMap + //groupings val groupingIndexes = precalculatedGroupingIndexes.getOrElse( body.collect{ case ga: GroupingAtom => ga }.map(ga => @@ -166,7 +192,39 @@ object JoinIndexes { ).toMap ) - new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, edb = false, groupingIndexes = groupingIndexes) + + val cons = consHint.map((c, m) => c.map((o, c, l, r, _) => + val fl = fixExpression(l, variables.apply, m) + val fr = fixExpression(r, variables.apply, m) + (o, c, fl, fr, (maxIndex(fl) ++ maxIndex(fr)).reduceOption(Math.max(_, _)).getOrElse(0))) + ) + .getOrElse( + constraints.map(con => + checkExpression(con.l, variables.keySet.toSet) + checkExpression(con.r, variables.keySet.toSet) + val sl = translateExpression(simplifyExpression(con.l), 
variables.apply) + val sr = translateExpression(simplifyExpression(con.r), variables.apply) + val sc = con.c match + case Comparison.EQ => StorageComparison.EQ + case Comparison.NEQ => StorageComparison.NEQ + case Comparison.LT => StorageComparison.LT + case Comparison.LTE => StorageComparison.LTE + case Comparison.GT => StorageComparison.GT + case Comparison.GTE => StorageComparison.GTE + ( + (sl, sr) match + case (StorageExpression.One(c1: Constant), StorageExpression.One(c2: Constant)) => Some(comparisons(sc)(getType(c1))(c1, c2)) + case _ => None + , + sc, + sl, + sr, + (maxIndex(sl) ++ maxIndex(sr)).reduceOption(Math.max(_, _)).getOrElse(0) + ) + ) + ) + + new JoinIndexes(bodyVars, constants.to(mutable.Map), projects, deps, rule, cxns, cons, constraints, negationInfo, edb = false, groupingIndexes = groupingIndexes) } // used to approximate poor user-defined order @@ -197,7 +255,7 @@ object JoinIndexes { nextOpt = None val newAtoms = originalK.atoms.head +: newBody.map(_._1) - val newHash = JoinIndexes.getRuleHash(newAtoms) + val newHash = JoinIndexes.getRuleHash(newAtoms, originalK.constraints) // println(s"\tOrder: ${newBody.map((a, _) => s"${sm.ns(a.rId)}:|${sortBy(a)}|").mkString("", ", ", "")}") // if (originalK.atoms.length > 3) @@ -242,7 +300,7 @@ object JoinIndexes { // println(s"\t\t\t==>next cxn to add: ${nextOpt.map(next => sm.ns.hashToAtom(next._1.hash)).getOrElse("None")}") val newAtoms = originalK.atoms.head +: newBody.map(_._1) - val newHash = JoinIndexes.getRuleHash(newAtoms) + val newHash = JoinIndexes.getRuleHash(newAtoms, originalK.constraints) // if (originalK.atoms.length > 3) // print(s"Rule: ${sm.printer.ruleToString(originalK.atoms)} => ") @@ -261,7 +319,7 @@ object JoinIndexes { presortSelect(sortBy, originalK, sm, -1) val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate( newHash, - JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes)) + JoinIndexes(originalK.atoms.head +: 
newBody.map(_._1), originalK.constraints, Some(originalK.cxns), Some((originalK.cons, originalK.pos2Term)), Some(originalK.groupingIndexes)) ) (input.map(c => ProjectJoinFilterOp(rId, newK, newBody.map((_, oldP) => c.childrenSO(oldP)): _*)), newK) } @@ -282,18 +340,146 @@ object JoinIndexes { presortSelect(sortBy, originalK, sm, deltaIdx) val newK = sm.allRulesAllIndexes(rId).getOrElseUpdate( newHash, - JoinIndexes(originalK.atoms.head +: newBody.map(_._1), Some(originalK.cxns), Some(originalK.groupingIndexes)) + JoinIndexes(originalK.atoms.head +: newBody.map(_._1), originalK.constraints, Some(originalK.cxns), Some((originalK.cons, originalK.pos2Term)), Some(originalK.groupingIndexes)) ) (newK.atoms.drop(1).map(a => input(originalK.atoms.drop(1).indexOf(a))), newK) } - def allOrders(rule: Seq[Atom]): AllIndexes = { - val idx = JoinIndexes(rule, None, None) - mutable.Map[String, JoinIndexes](rule.drop(1).permutations.map(r => - val toRet = JoinIndexes(rule.head +: r, Some(idx.cxns), Some(idx.groupingIndexes)) + def allOrders(rule: Seq[Atom], constraints: Seq[Constraint]): AllIndexes = { + val idx = JoinIndexes(rule, constraints, None, None, None) + mutable.Map[String, JoinIndexes](idx.atoms.drop(1).permutations.map(r => + val toRet = JoinIndexes(rule.head +: r, idx.constraints, Some(idx.cxns), Some(idx.cons, idx.pos2Term), Some(idx.groupingIndexes)) toRet.hash -> toRet ).toSeq:_*) } - def getRuleHash(rule: Seq[Atom]): String = rule.map(r => r.hash).mkString("", "", "") + def getRuleHash(rule: Seq[Atom], constraints: Seq[Constraint]): String = rule.map(r => r.hash).mkString("", "", "") + constraints.map(a => a.hash).mkString("", "", "") } + +// --- + +private def simplifyExpression(e: Expression): Expression = + enum ReduceOP: + case ADD, SUB, MUL, DIV, MOD + def reduceConstants(c1: Constant, c2: Constant, rop: ReduceOP): Constant = (c1, c2) match + case (i1: Int, i2: Int) => rop match + case ReduceOP.ADD => i1 + i2 + case ReduceOP.SUB => i1 - i2 + case ReduceOP.MUL 
=> i1 * i2 + case ReduceOP.DIV => i1 / i2 + case ReduceOP.MOD => i1 % i2 + case (i1: String, i2: String) => rop match + case ReduceOP.ADD => i1 + i2 + case _ => ??? + case _ => ??? + import Expression.* + e match + case One(t) => e + case Add(l, r) => simplifyExpression(l) match + case ne @ One(t) => (t, r) match + case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.ADD)) + case (c1: Constant, c2: Variable) => Add(One(c2), c1) + case _ => Add(ne, r) + case ne @ Add(l2, r2) => (r2, r) match + case (c1: Constant, c2: Constant) => Add(l2, reduceConstants(c1, c2, ReduceOP.ADD)) + case (c1: Constant, c2: Variable) => Add(Add(l2, c2), c1) + case _ => Add(ne, r) + case ne @ Sub(l2, r2) => (r2, r) match + case (c1: Constant, c2: Constant) => Add(l2, reduceConstants(c2, c1, ReduceOP.SUB)) + case (c1: Constant, c2: Variable) => Sub(Add(l2, c2), c1) + case _ => Add(ne, r) + case ne => Add(ne, r) + case Sub(l, r) => simplifyExpression(l) match + case ne @ One(t) => (t, r) match + case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.SUB)) + case _ => Sub(ne, r) // Case (constant, variable) could be simplified + case ne @ Add(l2, r2) => (r2, r) match + case (c1: Constant, c2: Constant) => Add(l2, reduceConstants(c1, c2, ReduceOP.SUB)) + case (c1: Constant, c2: Variable) => Add(Sub(l2, c2), c1) + case _ => Sub(ne, r) + case ne @ Sub(l2, r2) => (r2, r) match + case (c1: Constant, c2: Constant) => Sub(l2, reduceConstants(c1, c2, ReduceOP.ADD)) + case (c1: Constant, c2: Variable) => Sub(Sub(l2, c2), c1) + case _ => Sub(ne, r) + case ne => Sub(ne, r) + case Mul(l, r) => simplifyExpression(l) match + case ne @ One(t) => (t, r) match + case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.MUL)) + case (c1: Constant, c2: Variable) => Mul(One(c2), c1) + case _ => Mul(ne, r) + case ne @ Mul(l2, r2) => (r2, r) match + case (c1: Constant, c2: Constant) => Mul(l2, reduceConstants(c1, c2, ReduceOP.MUL)) + case (c1: Constant, c2: 
Variable) => Mul(Mul(l2, c2), c1) + case _ => Mul(ne, r) + case ne @ Div(l2, r2) => (r2, r) match + case (c1: Constant, c2: Variable) => Div(Mul(l2, c2), c1) + case _ => Mul(ne, r) + case ne => Mul(ne, r) + case Div(l, r) => simplifyExpression(l) match + case ne @ One(t) => (t, r) match + case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.DIV)) + case _ => Div(ne, r) // Case (constant, variable) could be simplified + case ne @ Mul(l2, r2) => (r2, r) match + case (c1: Constant, c2: Variable) => Mul(Div(l2, c2), c1) + case _ => Div(ne, r) + case ne @ Div(l2, r2) => (r2, r) match + case (c1: Constant, c2: Constant) => Div(l2, reduceConstants(c1, c2, ReduceOP.MUL)) + case (c1: Constant, c2: Variable) => Div(Div(l2, c2), c1) + case _ => Div(ne, r) + case ne => Div(ne, r) + case Mod(l, r) => simplifyExpression(l) match + case ne @ One(t) => (t, r) match + case (c1: Constant, c2: Constant) => One(reduceConstants(c1, c2, ReduceOP.MOD)) + case _ => Mod(ne, r) + case ne => Mod(ne, r) + +private def checkExpression(e: Expression, vars: Set[Variable]): Unit = + def checkTerm(t: Term): Unit = t match + case v: Variable => + if (!vars.contains(v)) + throw new Exception(s"Variable with varId ${v.oid} appears only in comparison atoms") + case _ => () + import Expression.* + e match + case One(t) => checkTerm(t) + case Add(l, r) => checkExpression(l, vars); checkTerm(r) + case Sub(l, r) => checkExpression(l, vars); checkTerm(r) + case Mul(l, r) => checkExpression(l, vars); checkTerm(r) + case Div(l, r) => checkExpression(l, vars); checkTerm(r) + case Mod(l, r) => checkExpression(l, vars); checkTerm(r) + + +private def translateExpression(e: Expression, m: Variable => Int): StorageExpression = + import Expression as E + import StorageExpression as SE + def translateTerm(t: Term): Either[Constant, Int] = t match + case c: Constant => Left(c) + case v: Variable => Right(m(v)) + e match + case E.One(t) => SE.One(translateTerm(t)) + case E.Add(l, r) => 
SE.Add(translateExpression(l, m), translateTerm(r)) + case E.Sub(l, r) => SE.Sub(translateExpression(l, m), translateTerm(r)) + case E.Mul(l, r) => SE.Mul(translateExpression(l, m), translateTerm(r)) + case E.Div(l, r) => SE.Div(translateExpression(l, m), translateTerm(r)) + case E.Mod(l, r) => SE.Mod(translateExpression(l, m), translateTerm(r)) + +private def fixExpression(se: StorageExpression, m: Variable => Int, rm: Int => Term): StorageExpression = + import StorageExpression.* + se match + case One(t) => One(t.map(x => m(rm(x).asInstanceOf[Variable]))) + case Add(l, r) => Add(fixExpression(l, m, rm), r.map(x => m(rm(x).asInstanceOf[Variable]))) + case Sub(l, r) => Sub(fixExpression(l, m, rm), r.map(x => m(rm(x).asInstanceOf[Variable]))) + case Mul(l, r) => Mul(fixExpression(l, m, rm), r.map(x => m(rm(x).asInstanceOf[Variable]))) + case Div(l, r) => Div(fixExpression(l, m, rm), r.map(x => m(rm(x).asInstanceOf[Variable]))) + case Mod(l, r) => Mod(fixExpression(l, m, rm), r.map(x => m(rm(x).asInstanceOf[Variable]))) + +private def maxIndex(se: StorageExpression): Option[Int] = + import StorageExpression.* + se match + case One(t) => t.toOption + case Add(l, r) => (maxIndex(l) ++ r.toOption).reduceOption(Math.max(_, _)) + case Sub(l, r) => (maxIndex(l) ++ r.toOption).reduceOption(Math.max(_, _)) + case Mul(l, r) => (maxIndex(l) ++ r.toOption).reduceOption(Math.max(_, _)) + case Div(l, r) => (maxIndex(l) ++ r.toOption).reduceOption(Math.max(_, _)) + case Mod(l, r) => (maxIndex(l) ++ r.toOption).reduceOption(Math.max(_, _)) + \ No newline at end of file diff --git a/src/main/scala/datalog/execution/LambdaCompiler.scala b/src/main/scala/datalog/execution/LambdaCompiler.scala index 0acd0009..68ee62f9 100644 --- a/src/main/scala/datalog/execution/LambdaCompiler.scala +++ b/src/main/scala/datalog/execution/LambdaCompiler.scala @@ -148,8 +148,13 @@ class LambdaCompiler(val storageManager: StorageManager)(using JITOptions) exten } } - case ComplementOp(arity) => - 
_.getComplement(arity) + case NegationOp(child, cols) => + val tmp = cols.map(_.exists(_.isEmpty)) + val clh = compile(child) + sm => + val compl = sm.getGroundOf(cols) + val nq = sm.zeroOut(clh(sm), tmp) + sm.diff(compl, nq) case ScanEDBOp(rId) => if (storageManager.edbContains(rId)) diff --git a/src/main/scala/datalog/execution/NaiveExecutionEngine.scala b/src/main/scala/datalog/execution/NaiveExecutionEngine.scala index a1e135b4..b637b920 100644 --- a/src/main/scala/datalog/execution/NaiveExecutionEngine.scala +++ b/src/main/scala/datalog/execution/NaiveExecutionEngine.scala @@ -1,6 +1,6 @@ package datalog.execution -import datalog.dsl.{Atom, Constant, Term, Variable, GroupingAtom, AggOp} +import datalog.dsl.{Atom, Constant, Term, Variable, GroupingAtom, AggOp, Constraint} import datalog.storage.{RelationId, CollectionsStorageManager, StorageManager, EDB, StorageAggOp} import datalog.tools.Debug.debug @@ -34,22 +34,21 @@ class NaiveExecutionEngine(val storageManager: StorageManager, stratified: Boole get(storageManager.ns(name)) } - def insertIDB(rId: RelationId, rule: Seq[Atom]): Unit = { - precedenceGraph.addNode(rule) - idbs.getOrElseUpdate(rId, mutable.ArrayBuffer[IndexedSeq[Atom]]()).addOne(rule.toIndexedSeq) - val jIdx = getOperatorKey(rule) - prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(jIdx) - storageManager.addConstantsToDomain(jIdx.constIndexes.values.toSeq) + def insertIDB(rId: RelationId, rule: Seq[Atom | Constraint]): Unit = { + val (atoms, constraints) = rule.partitionMap{ + case a: Atom => Left(a) + case c: Constraint => Right(c) + } - // We need to add the constants occurring in the grouping predicates of the grouping atoms - rule.collect{ case ga: GroupingAtom => ga}.foreach(ga => - storageManager.addConstantsToDomain(jIdx.groupingIndexes(ga.hash).constIndexes.values.toSeq) - ) + precedenceGraph.addNode(atoms) + idbs.getOrElseUpdate(rId, mutable.ArrayBuffer[IndexedSeq[Atom]]()).addOne(atoms.toIndexedSeq) + val 
jIdx = getOperatorKey(atoms, constraints) + prebuiltOpKeys.getOrElseUpdate(rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(jIdx) } def insertEDB(rule: Atom): Unit = { if (!storageManager.edbContains(rule.rId)) - prebuiltOpKeys.getOrElseUpdate(rule.rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(JoinIndexes(IndexedSeq(), mutable.Map(), IndexedSeq(), Seq((PredicateType.POSITIVE, rule.rId)), Seq(rule), mutable.Map.empty, true)) + prebuiltOpKeys.getOrElseUpdate(rule.rId, mutable.ArrayBuffer[JoinIndexes]()).addOne(JoinIndexes(IndexedSeq(), mutable.Map(), IndexedSeq(), Seq((PredicateType.POSITIVE, rule.rId)), Seq(rule), mutable.Map.empty, Seq.empty, Seq.empty, Map.empty, true)) storageManager.insertEDB(rule) } diff --git a/src/main/scala/datalog/execution/QuoteCompiler.scala b/src/main/scala/datalog/execution/QuoteCompiler.scala index 59088c89..8fe5963c 100644 --- a/src/main/scala/datalog/execution/QuoteCompiler.scala +++ b/src/main/scala/datalog/execution/QuoteCompiler.scala @@ -1,6 +1,6 @@ package datalog.execution -import datalog.dsl.{Atom, Constant, Term, Variable} +import datalog.dsl.{Atom, Constant, Term, Variable, Comparison, Expression, Constraint} import datalog.execution.ir.* import datalog.storage.* import datalog.tools.Debug.debug @@ -62,6 +62,49 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend } } + given ToExpr[StorageComparison] with { + def apply(x: StorageComparison)(using Quotes) = { + x match + case StorageComparison.EQ => '{ StorageComparison.EQ } + case StorageComparison.NEQ => '{ StorageComparison.NEQ } + case StorageComparison.LT => '{ StorageComparison.LT } + case StorageComparison.LTE => '{ StorageComparison.LTE } + case StorageComparison.GT => '{ StorageComparison.GT } + case StorageComparison.GTE => '{ StorageComparison.GTE } + + } + } + + given ToExpr[Expression] with { + def apply(x: Expression)(using Quotes) = { + x match + case Expression.One(t) => '{ Expression.One( ${ Expr(t) } ) } + case 
Expression.Add(l, r) => '{ Expression.Add( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Sub(l, r) => '{ Expression.Sub( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Mul(l, r) => '{ Expression.Mul( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Div(l, r) => '{ Expression.Div( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Mod(l, r) => '{ Expression.Mod( ${ Expr(l) }, ${ Expr(r) } ) } + } + } + + given ToExpr[StorageExpression] with { + def apply(x: StorageExpression)(using Quotes) = { + x match + case StorageExpression.One(t) => '{ StorageExpression.One( ${ Expr(t) } ) } + case StorageExpression.Add(l, r) => '{ StorageExpression.Add( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Sub(l, r) => '{ StorageExpression.Sub( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Mul(l, r) => '{ StorageExpression.Mul( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Div(l, r) => '{ StorageExpression.Div( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Mod(l, r) => '{ StorageExpression.Mod( ${ Expr(l) }, ${ Expr(r) } ) } + } + } + + given ToExpr[Constraint] with { + def apply(x: Constraint)(using Quotes) = { + '{ Constraint( Comparison.fromOrdinal( ${ Expr(x.c.ordinal) } ), ${ Expr(x.l) }, ${ Expr(x.r) } ) } + } + } + given ToExpr[JoinIndexes] with { def apply(x: JoinIndexes)(using Quotes) = { '{ @@ -72,6 +115,9 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend ${ Expr(x.deps) }, ${ Expr(x.atoms) }, ${ Expr(x.cxns) }, + ${ Expr(x.cons) }, + ${ Expr(x.constraints) }, + ${ Expr(x.negationInfo) }, ${ Expr(x.edb) } ) } @@ -135,8 +181,14 @@ class QuoteCompiler(val storageManager: StorageManager)(using JITOptions) extend } } - case ComplementOp(arity) => - '{ $stagedSM.getComplement(${ Expr(arity) }) } + case NegationOp(child, cols) => + val tmp = cols.map(_.exists(_.isEmpty)) + val clh = compileIRRelOp(child) + '{ + val compl = $stagedSM.getGroundOf(${ Expr(cols) }) + val nq = $stagedSM.zeroOut($clh, ${ 
Expr(tmp) }) + $stagedSM.diff(compl, nq) + } case ScanEDBOp(rId) => if (storageManager.edbContains(rId)) diff --git a/src/main/scala/datalog/execution/StagedExecutionEngine.scala b/src/main/scala/datalog/execution/StagedExecutionEngine.scala index d8866911..76395bef 100644 --- a/src/main/scala/datalog/execution/StagedExecutionEngine.scala +++ b/src/main/scala/datalog/execution/StagedExecutionEngine.scala @@ -1,11 +1,11 @@ package datalog.execution -import datalog.dsl.{Atom, Constant, Term, Variable, GroupingAtom, AggOp} +import datalog.dsl.{Atom, Constant, Term, Variable, GroupingAtom, AggOp, Comparison, Constraint} import datalog.execution import datalog.execution.ast.* import datalog.execution.ast.transform.{ASTTransformerContext, CopyEliminationPass, Transformer} import datalog.execution.ir.* -import datalog.storage.{DB, EDB, KNOWLEDGE, StorageManager, StorageAggOp} +import datalog.storage.{DB, EDB, KNOWLEDGE, StorageManager, StorageAggOp, StorageComparison} import datalog.tools.Debug.debug import java.util.concurrent.{Executors, ForkJoinPool} @@ -53,16 +53,21 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp get(storageManager.ns(name)) } - def insertIDB(rId: Int, ruleSeq: Seq[Atom]): Unit = { - precedenceGraph.addNode(ruleSeq) + def insertIDB(rId: Int, ruleSeq: Seq[Atom | Constraint]): Unit = { + val (atoms, constraints) = ruleSeq.partitionMap{ + case a: Atom => Left(a) + case c: Constraint => Right(c) + } + + precedenceGraph.addNode(atoms) // println(s"${storageManager.printer.ruleToString(ruleSeq)}") - var rule = ruleSeq - var k = JoinIndexes(rule, None, None) + var rule = atoms + var k = JoinIndexes(rule, constraints, None, None, None) storageManager.allRulesAllIndexes.getOrElseUpdate(rId, mutable.Map[String, JoinIndexes]()).addOne(k.hash, k) if (rule.length <= heuristics.max_length_cache) - val allK = JoinIndexes.allOrders(rule) + val allK = JoinIndexes.allOrders(rule, constraints) storageManager.allRulesAllIndexes(rId) 
++= allK if (defaultJITOptions.sortOrder == SortOrder.Sel) // sort before inserting, just in case EDBs are defined @@ -77,7 +82,7 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp -1 ) rule = rule.head +: sortedBody.map(_._1) - k = JoinIndexes(rule, Some(k.cxns), Some(k.groupingIndexes)) + k = JoinIndexes(rule, constraints, Some(k.cxns), Some((k.cons, k.pos2Term)), Some(k.groupingIndexes)) storageManager.allRulesAllIndexes(rId).addOne(k.hash, k) else if (defaultJITOptions.sortOrder == SortOrder.Badluck) // mimic "bad luck" program definition, so ingest rules in a bad order and then don't update them. val (sortedBody, newHash) = JoinIndexes.presortSelectWorst( @@ -91,7 +96,7 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp -1 ) rule = rule.head +: sortedBody.map(_._1) - k = JoinIndexes(rule, Some(k.cxns), Some(k.groupingIndexes)) + k = JoinIndexes(rule, constraints, Some(k.cxns), Some((k.cons, k.pos2Term)), Some(k.groupingIndexes)) storageManager.allRulesAllIndexes(rId).addOne(k.hash, k) // println(s"${storageManager.printer.ruleToString(rule)}") @@ -133,8 +138,11 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp case x: Variable => VarTerm(x) case x: Constant => ConstTerm(x) }, b.negated) + ) ++ constraints.map(c => + ConstraintAtom(c.c, c.l, c.r) ), rule, + constraints, k )) } @@ -349,8 +357,8 @@ class StagedExecutionEngine(val storageManager: StorageManager, val defaultJITOp case op: ScanEDBOp => op.run(storageManager) - case op: ComplementOp => - op.run(storageManager) + case op: NegationOp => + op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o))) case op: ProjectJoinFilterOp => op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o))) diff --git a/src/main/scala/datalog/execution/StagedSnippetCompiler.scala b/src/main/scala/datalog/execution/StagedSnippetCompiler.scala index 
162a2933..4e035234 100644 --- a/src/main/scala/datalog/execution/StagedSnippetCompiler.scala +++ b/src/main/scala/datalog/execution/StagedSnippetCompiler.scala @@ -1,8 +1,8 @@ package datalog.execution -import datalog.dsl.{Atom, Constant, Term, Variable} +import datalog.dsl.{Atom, Constant, Term, Variable, Comparison, Expression, Constraint} import datalog.execution.ir.* -import datalog.storage.{DB, EDB, KNOWLEDGE, StorageManager, StorageAggOp} +import datalog.storage.{DB, EDB, KNOWLEDGE, StorageManager, StorageAggOp, StorageComparison, StorageExpression} import datalog.tools.Debug.debug import scala.collection.mutable @@ -56,6 +56,48 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt } } + given ToExpr[StorageComparison] with { + def apply(x: StorageComparison)(using Quotes) = { + x match + case StorageComparison.EQ => '{ StorageComparison.EQ } + case StorageComparison.NEQ => '{ StorageComparison.NEQ } + case StorageComparison.LT => '{ StorageComparison.LT } + case StorageComparison.LTE => '{ StorageComparison.LTE } + case StorageComparison.GT => '{ StorageComparison.GT } + case StorageComparison.GTE => '{ StorageComparison.GTE } + + } + } + + given ToExpr[Expression] with { + def apply(x: Expression)(using Quotes) = { + x match + case Expression.One(t) => '{ Expression.One( ${ Expr(t) } ) } + case Expression.Add(l, r) => '{ Expression.Add( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Sub(l, r) => '{ Expression.Sub( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Mul(l, r) => '{ Expression.Mul( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Div(l, r) => '{ Expression.Div( ${ Expr(l) }, ${ Expr(r) } ) } + case Expression.Mod(l, r) => '{ Expression.Mod( ${ Expr(l) }, ${ Expr(r) } ) } + } + } + + given ToExpr[StorageExpression] with { + def apply(x: StorageExpression)(using Quotes) = { + x match + case StorageExpression.One(t) => '{ StorageExpression.One( ${ Expr(t) } ) } + case StorageExpression.Add(l, r) => '{ 
StorageExpression.Add( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Sub(l, r) => '{ StorageExpression.Sub( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Mul(l, r) => '{ StorageExpression.Mul( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Div(l, r) => '{ StorageExpression.Div( ${ Expr(l) }, ${ Expr(r) } ) } + case StorageExpression.Mod(l, r) => '{ StorageExpression.Mod( ${ Expr(l) }, ${ Expr(r) } ) } + } + } + + given ToExpr[Constraint] with { + def apply(x: Constraint)(using Quotes) = { + '{ Constraint( Comparison.fromOrdinal( ${ Expr(x.c.ordinal) } ), ${ Expr(x.l) }, ${ Expr(x.r) } ) } + } + } given ToExpr[JoinIndexes] with { def apply(x: JoinIndexes)(using Quotes) = { @@ -67,6 +109,9 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt ${ Expr(x.deps) }, ${ Expr(x.atoms) }, ${ Expr(x.cxns) }, + ${ Expr(x.cons) }, + ${ Expr(x.constraints) }, + ${ Expr(x.negationInfo) }, ${ Expr(x.edb) }, ) } } @@ -127,8 +172,13 @@ class StagedSnippetCompiler(val storageManager: StorageManager)(using val jitOpt } } - case ComplementOp(arity) => - '{ $stagedSM.getComplement(${ Expr(arity) }) } + case NegationOp(child, cols) => + val tmp = cols.map(_.exists(_.isEmpty)) + '{ + val compl = $stagedSM.getGroundOf(${ Expr(cols) }) + val nq = $stagedSM.zeroOut($stagedFns.head($stagedSM), ${ Expr(tmp) }) + $stagedSM.diff(compl, nq) + } case ScanEDBOp(rId) => if (storageManager.edbContains(rId)) diff --git a/src/main/scala/datalog/execution/StagedSnippetExecutionEngine.scala b/src/main/scala/datalog/execution/StagedSnippetExecutionEngine.scala index acd41e6d..fedcaba7 100644 --- a/src/main/scala/datalog/execution/StagedSnippetExecutionEngine.scala +++ b/src/main/scala/datalog/execution/StagedSnippetExecutionEngine.scala @@ -115,8 +115,8 @@ class StagedSnippetExecutionEngine(override val storageManager: StorageManager, case op: DebugPeek => op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o))) - 
case op: ComplementOp => - op.run(storageManager) + case op: NegationOp => + op.run_continuation(storageManager, op.children.map(o => (sm: StorageManager) => jit(o))) case _ => throw new Exception(s"Error: interpretRelOp called with unit operation: code=${irTree.code}") } diff --git a/src/main/scala/datalog/execution/ast/ASTNode.scala b/src/main/scala/datalog/execution/ast/ASTNode.scala index 2f1f803f..bec7dc64 100644 --- a/src/main/scala/datalog/execution/ast/ASTNode.scala +++ b/src/main/scala/datalog/execution/ast/ASTNode.scala @@ -1,6 +1,6 @@ package datalog.execution.ast -import datalog.dsl.{Atom, Constant, Term, Variable} +import datalog.dsl.{Atom, Constant, Term, Variable, Comparison, Expression, Constraint} import datalog.execution.{JoinIndexes, GroupingJoinIndexes} import datalog.storage.{RelationId, StorageAggOp} @@ -17,7 +17,7 @@ abstract class AtomNode() extends ASTNode {} case class LogicAtom(relation: RelationId, terms: Seq[ASTNode], negated: Boolean) extends AtomNode {} -case class RuleNode(head: ASTNode, body: Seq[ASTNode], dslAtoms: Seq[Atom], currentK: JoinIndexes) extends ASTNode {} +case class RuleNode(head: ASTNode, body: Seq[ASTNode], dslAtoms: Seq[Atom], dslConstraints: Seq[Constraint], currentK: JoinIndexes) extends ASTNode {} abstract class TermNode(value: Term) extends ASTNode {} @@ -27,4 +27,6 @@ case class ConstTerm(value: Constant) extends TermNode(value) {} case class LogicGroupingAtom(gp: LogicAtom, gv: Seq[VarTerm], ags: Seq[(AggOpNode, VarTerm)]) extends AtomNode {} -case class AggOpNode(aggOp: StorageAggOp, term: TermNode) extends ASTNode {} \ No newline at end of file +case class AggOpNode(aggOp: StorageAggOp, term: TermNode) extends ASTNode {} + +case class ConstraintAtom(c: Comparison, l: Expression, r: Expression) extends ASTNode {} // Decompose expressions into subtrees diff --git a/src/main/scala/datalog/execution/ast/transform/CopyEliminationPass.scala b/src/main/scala/datalog/execution/ast/transform/CopyEliminationPass.scala 
index 17141093..f15bf5d9 100644 --- a/src/main/scala/datalog/execution/ast/transform/CopyEliminationPass.scala +++ b/src/main/scala/datalog/execution/ast/transform/CopyEliminationPass.scala @@ -22,7 +22,7 @@ class CopyEliminationPass()(using ASTTransformerContext) extends Transformer { case AllRulesNode(rules, _, edb) => if (rules.size == 1 && !edb) checkAlias(rules.head) - case RuleNode(head, body, _, _) => + case RuleNode(head, body, _, _, _) => if (body.size == 1) // for now just subst simple equality (head, body(0)) match { case (h: LogicAtom, b: LogicAtom) => @@ -44,7 +44,7 @@ class CopyEliminationPass()(using ASTTransformerContext) extends Transformer { ) // delete aliased rules case AllRulesNode(rules, rId, edb) => AllRulesNode(rules.map(transform), rId, edb) - case RuleNode(head, body, atoms, k) => + case RuleNode(head, body, atoms, constraints, k) => var aliased = false var newK = k val transformedAtoms = atoms.head +: atoms.drop(1).map(a => @@ -59,16 +59,16 @@ class CopyEliminationPass()(using ASTTransformerContext) extends Transformer { a ) if (aliased) - newK = JoinIndexes(transformedAtoms, None, None) + newK = JoinIndexes(transformedAtoms, constraints, None, None, None) ctx.sm.allRulesAllIndexes.getOrElseUpdate(transformedAtoms.head.rId, mutable.Map[String, JoinIndexes]()).addOne(newK.hash, newK) if (body.size < heuristics.max_length_cache) - val allK = JoinIndexes.allOrders(transformedAtoms) + val allK = JoinIndexes.allOrders(transformedAtoms, constraints) ctx.sm.allRulesAllIndexes(transformedAtoms.head.rId) ++= allK ctx.precedenceGraph.addNode(transformedAtoms) ctx.precedenceGraph.updateNodeAlias(ctx.aliases) - RuleNode(transform(head), body.map(transform), transformedAtoms, newK) + RuleNode(transform(head), body.map(transform), transformedAtoms, constraints, newK) case n: AtomNode => n match { case LogicAtom(relation, terms, neg) => LogicAtom(ctx.aliases.getOrElse(relation, relation), terms, neg) @@ -76,6 +76,7 @@ class CopyEliminationPass()(using 
ASTTransformerContext) extends Transformer { LogicGroupingAtom(LogicAtom(ctx.aliases.getOrElse(relation, relation), terms, neg), gv, ags) } case n: TermNode => n + case c: ConstraintAtom => c } else node diff --git a/src/main/scala/datalog/execution/ir/IROp.scala b/src/main/scala/datalog/execution/ir/IROp.scala index b1d490d9..0c7c1e24 100644 --- a/src/main/scala/datalog/execution/ir/IROp.scala +++ b/src/main/scala/datalog/execution/ir/IROp.scala @@ -18,7 +18,7 @@ import scala.util.{Failure, Success} enum OpCode: case PROGRAM, SWAP_CLEAR, SEQ, SCAN, SCANEDB, SCAN_DISCOVERED, - COMPLEMENT, + NEGATION, SPJ, INSERT, UNION, DIFF, GROUPING, DEBUG, DEBUGP, DOWHILE, UPDATE_DISCOVERED, @@ -196,14 +196,20 @@ case class InsertOp(rId: RelationId, db: DB, knowledge: KNOWLEDGE, override val } } -case class ComplementOp(arity: Int)(using JITOptions) extends IROp[EDB] { - val code: OpCode = OpCode.COMPLEMENT +case class NegationOp(child: IROp[EDB], cols: Seq[Either[Constant, Seq[(RelationId, Int)]]])(using JITOptions) extends IROp[EDB](child) { + val code: OpCode = OpCode.NEGATION override def run(storageManager: StorageManager): EDB = - storageManager.getComplement(arity) + val tmp = cols.map(_.exists(_.isEmpty)) + val compl = storageManager.getGroundOf(cols) + val nq = storageManager.zeroOut(child.run(storageManager), tmp) + storageManager.diff(compl, nq) override def run_continuation(storageManager: StorageManager, opFns: Seq[CompiledFn[EDB]]): EDB = - run(storageManager) // bc leaf node, no difference for continuation or run + val tmp = cols.map(_.exists(_.isEmpty)) + val compl = storageManager.getGroundOf(cols) + val nq = storageManager.zeroOut(opFns(0)(storageManager), tmp) + storageManager.diff(compl, nq) } case class ScanOp(rId: RelationId, db: DB, knowledge: KNOWLEDGE)(using JITOptions) extends IROp[EDB] { diff --git a/src/main/scala/datalog/execution/ir/IRTreeGenerator.scala b/src/main/scala/datalog/execution/ir/IRTreeGenerator.scala index e36ff2e9..c43d2a51 100644 --- 
a/src/main/scala/datalog/execution/ir/IRTreeGenerator.scala +++ b/src/main/scala/datalog/execution/ir/IRTreeGenerator.scala @@ -57,7 +57,7 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) { allRes = allRes :+ ScanEDBOp(rId) // TODO: potentially change this to Discovered not EDB // if(allRes.length == 1) allRes.head else UnionOp(OpCode.EVAL_RULE_NAIVE, allRes:_*) - case RuleNode(head, _, atoms, k) => + case RuleNode(head, _, atoms, _, k) => val r = head.asInstanceOf[LogicAtom].relation if (k.edb) ScanEDBOp(r) @@ -68,9 +68,9 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) { val q = ScanOp(r, DB.Derived, KNOWLEDGE.Known) typ match case PredicateType.NEGATED => - val arity = k.atoms(i + 1).terms.length - val res = DiffOp(ComplementOp(arity), q) - debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}\n\tarity=$arity") + val cols = k.negationInfo(k.atoms(i + 1).hash) + val res = NegationOp(q, cols) + debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}") res case PredicateType.GROUPING => val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom] @@ -94,7 +94,7 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) { allRes = allRes :+ ScanEDBOp(rId) // if(allRes.length == 1) allRes.head else UnionOp(OpCode.EVAL_RULE_SN, allRes:_*) // None bc union of unions so no point in sorting - case RuleNode(head, body, atoms, k) => + case RuleNode(head, body, atoms, _, k) => val r = head.asInstanceOf[LogicAtom].relation if (k.edb) ScanEDBOp(r) @@ -119,9 +119,9 @@ class IRTreeGenerator(using val ctx: InterpreterContext)(using JITOptions) { ScanOp(r, DB.Derived, KNOWLEDGE.Known) typ match case PredicateType.NEGATED => - val arity = k.atoms(i + 1).terms.length - val res = DiffOp(ComplementOp(arity), q) - debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}\n\tarity=$arity") 
+ val cols = k.negationInfo(k.atoms(i + 1).hash) + val res = NegationOp(q, cols) + debug(s"found negated relation, rule=", () => s"${ctx.storageManager.printer.ruleToString(k.atoms)}") res case PredicateType.GROUPING => val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom] diff --git a/src/main/scala/datalog/storage/CollectionsStorageManager.scala b/src/main/scala/datalog/storage/CollectionsStorageManager.scala index 0384de86..ebaa979e 100644 --- a/src/main/scala/datalog/storage/CollectionsStorageManager.scala +++ b/src/main/scala/datalog/storage/CollectionsStorageManager.scala @@ -11,7 +11,6 @@ import scala.collection.{Iterator, immutable, mutable} abstract class CollectionsStorageManager(override val ns: NS) extends StorageManager(ns) { // "database", i.e. relationID => Relation protected val edbs: CollectionsDatabase = CollectionsDatabase() // raw user-supplied EDBs from initialization. - val edbDomain: mutable.Set[StorageTerm] = mutable.Set.empty // incrementally grow the total domain of all EDBs, used for calculating complement of negated predicates protected val discoveredFacts: CollectionsDatabase = CollectionsDatabase() // all EDBs + facts discovered in previous strata var knownDbId: KnowledgeId = -1 var newDbId: KnowledgeId = -1 @@ -65,48 +64,37 @@ abstract class CollectionsStorageManager(override val ns: NS) extends StorageMan else edbs(rule.rId) = CollectionsEDB() edbs(rule.rId).addOne(CollectionsRow(rule.terms)) - edbDomain.addAll(rule.terms) - } - /* Call when adding an IDB rule so domain can grow incrementally */ - override def addConstantsToDomain(constants: Seq[StorageTerm]): Unit = { - edbDomain.addAll(constants) } + def getEmptyEDB(): CollectionsEDB = CollectionsEDB() def getEDB(rId: RelationId): CollectionsEDB = edbs(rId) def edbContains(rId: RelationId): Boolean = edbs.contains(rId) def getAllEDBS(): mutable.Map[RelationId, Any] = edbs.wrapped.asInstanceOf[mutable.Map[RelationId, Any]] - /** - * Used for computing DOM(k) of a negated relation. 
Returns the (unchanging) set of possible EDB values + - * constants in all IDB rules. Currently unused because we incrementally add elements to the domain but may - * be useful if we want a domain containing only predicates from <= strata. - */ -// Comment out until we can track domain in something other than indexes -// def computeDomain(): Set[StorageTerm] = { -// val constants = mutable.Set[StorageTerm]() -// edbs.foreach((_, rows) => // avoid map or flatMap for CollectionsDatabase, CollectionRow -// rows.foreach(row => -// constants.addAll(row.toSeq) -// ) -// ) -// constants.addAll(allRulesAllIndexes.flatMap((_, allIndexes) => -// allIndexes.head._2.constIndexes.values -// )) -// constants.toSet -// } - - /** - * Compute Dom * Dom * ... arity # times - */ - override def getComplement(arity: Int): CollectionsEDB = { - // short but inefficient - val res = List.fill(arity)(edbDomain).flatten.combinations(arity).flatMap(_.permutations).toSeq - CollectionsEDB( - res.map(r => CollectionsRow(r.toSeq)):_* - ) + def getGroundOf(cols: Seq[Either[StorageConstant, Seq[(RelationId, Int)]]]): CollectionsEDB = { + val ctans = cols.collect{ case Right(value) => value }.flatten.groupBy(_._1).view.mapValues(v => v.map(_._2).distinct.map(col => + val der = getKnownDerivedDB(v.head._1).wrapped.map(r => r(col).asInstanceOf[StorageConstant]) + val del = getKnownDeltaDB(v.head._1).wrapped.map(r => r(col).asInstanceOf[StorageConstant]) + col -> (der.toSet ++ del.toSet) + ).toMap) + + val colCtans = cols.map{ + case Left(value) => Set(value) + case Right(value) => if value.isEmpty then Set(0) else value.map((rel, col) => ctans(rel)(col)).reduceLeft(_.intersect(_)) + } + + val first = CollectionsEDB(colCtans.head.map(c => CollectionsRow(Seq(c))).toSeq*) + colCtans.tail.foldLeft(first)((acc, s) => acc.flatMap(c => s.map(n => c.concat(CollectionsRow(Seq(n)))))) } + def zeroOut(input: EDB, cols: Seq[Boolean]): CollectionsEDB = + val tmp = asCollectionsEDB(input) + if cols.exists(identity) 
then + tmp.map(r => CollectionsRow(r.wrapped.zip(cols).map((v, a) => if a then 0 else v))).distinct() + else tmp + // Read intermediate results + def getKnownDerivedDB(rId: RelationId): CollectionsEDB = derivedDB(knownDbId).getOrElse(rId, discoveredFacts.getOrElse(rId, CollectionsEDB())) def getNewDerivedDB(rId: RelationId): CollectionsEDB = diff --git a/src/main/scala/datalog/storage/DefaultStorageManager.scala b/src/main/scala/datalog/storage/DefaultStorageManager.scala index 9bee586b..554eb9dd 100644 --- a/src/main/scala/datalog/storage/DefaultStorageManager.scala +++ b/src/main/scala/datalog/storage/DefaultStorageManager.scala @@ -27,26 +27,38 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager val kCmp = k.constIndexes.isEmpty || k.constIndexes.forall((idx, const) => idx >= maxIdx || get(idx) == const ) - vCmp && kCmp + val cCmp = k.cons.isEmpty || k.cons.forall((o, sc, l, r, idx) => + o.exists(x => x) || idx >= maxIdx || { + val tpe = getType(get(idx).asInstanceOf[StorageConstant]) + val el = buildExpression(l, tpe) + val er = buildExpression(r, tpe) + val op = comparisons(sc)(tpe) + op(el(get), er(get)) + } + ) + + vCmp && kCmp && cCmp } override def joinHelper(inputEDB: Seq[EDB], k: JoinIndexes): CollectionsEDB = { - val inputs = asCollectionsSeqEDB(inputEDB) - inputs - .reduceLeft((outer: CollectionsEDB, inner: CollectionsEDB) => { - outer.flatMap(outerTuple => { - inner.flatMap(innerTuple => { - val get = (i: Int) => { - outerTuple.applyOrElse(i, j => innerTuple(j - outerTuple.length)) - } - if(scanFilter(k, innerTuple.length + outerTuple.length)(get)) - Some(outerTuple.concat(innerTuple)) - else - None + if k.cons.exists(_._1.exists(x => !x)) then getEmptyEDB() // Some constraint is false + else + val inputs = asCollectionsSeqEDB(inputEDB) + inputs + .reduceLeft((outer: CollectionsEDB, inner: CollectionsEDB) => { + outer.flatMap(outerTuple => { + inner.flatMap(innerTuple => { + val get = (i: Int) => { + 
outerTuple.applyOrElse(i, j => innerTuple(j - outerTuple.length)) + } + if(scanFilter(k, innerTuple.length + outerTuple.length)(get)) + Some(outerTuple.concat(innerTuple)) + else + None + }) }) }) - }) - .filter(r => scanFilter(k, r.length)(r.apply)) + .filter(r => scanFilter(k, r.length)(r.apply)) } override def projectHelper(input: EDB, k: JoinIndexes): CollectionsEDB = { @@ -79,6 +91,18 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager ) } + private inline def consfilter(cons: Seq[(Option[Boolean], StorageComparison, StorageExpression, StorageExpression, Int)], tuple: CollectionsRow): Boolean = { + cons.isEmpty || cons.forall((o, sc, l, r, idx) => + o.exists(x => x) || idx >= tuple.length || { + val tpe = getType(tuple(idx).asInstanceOf[StorageConstant]) + val el = buildExpression(l, tpe) + val er = buildExpression(r, tpe) + val op = comparisons(sc)(tpe) + op(el(tuple.apply), er(tuple.apply)) + } + ) + } + override def joinProjectHelper_withHash(inputsEDB: Seq[EDB], rId: Int, hash: String, onlineSort: Boolean): CollectionsEDB = { val originalK = allRulesAllIndexes(rId)(hash) val inputs = asCollectionsSeqEDB(inputsEDB) @@ -88,6 +112,7 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager .filter(e => val filteredC = originalK.constIndexes.filter((ind, _) => ind < e.length) prefilter(filteredC, 0, e) && filteredC.size == originalK.constIndexes.size) + .filter(r => consfilter(originalK.cons, r)) .map(t => CollectionsRow(originalK.projIndexes.flatMap((typ, idx) => typ match { @@ -109,8 +134,8 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager val (inner, outer) = // on the fly swapping of join order if (atomI > 1 && onlineSort && outerT.length > innerT.length) val body = k.atoms.drop(1) - val newerHash = JoinIndexes.getRuleHash(Seq(k.atoms.head, body(atomI)) ++ body.dropRight(body.length - atomI) ++ body.drop(atomI + 1)) - k = 
allRulesAllIndexes(rId).getOrElseUpdate(newerHash, JoinIndexes(originalK.atoms.head +: body, Some(originalK.cxns), Some(originalK.groupingIndexes))) + val newerHash = JoinIndexes.getRuleHash(Seq(k.atoms.head, body(atomI)) ++ body.dropRight(body.length - atomI) ++ body.drop(atomI + 1), k.constraints) + k = allRulesAllIndexes(rId).getOrElseUpdate(newerHash, JoinIndexes(originalK.atoms.head +: body, originalK.constraints, Some(originalK.cxns), Some((originalK.cons, originalK.pos2Term)), Some(originalK.groupingIndexes))) (outerT, innerT) else (innerT, outerT) @@ -125,6 +150,7 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager prefilter(k.constIndexes.filter((ind, _) => ind >= outerTuple.length && ind < (outerTuple.length + i.length)), outerTuple.length, i) && toJoin(k.varIndexes, outerTuple, i) ) .map(innerTuple => outerTuple.concat(innerTuple))) + .filter(r => consfilter(k.cons, r)) // intermediateCardinalities = intermediateCardinalities :+ edbResult.length (edbResult, atomI + 1, k) ) @@ -150,6 +176,7 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager .filter(e => val filteredC = originalK.constIndexes.filter((ind, _) => ind < e.length) prefilter(filteredC, 0, e) && filteredC.size == originalK.constIndexes.size) + .filter(r => consfilter(originalK.cons, r)) .map(t => CollectionsRow(originalK.projIndexes.flatMap((typ, idx) => typ match { @@ -174,7 +201,7 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager val (inner, outer) = if (atomI > 1 && onlineSort && outerT.length > innerT.length) val body = k.atoms.drop(1) - k = JoinIndexes(Seq(k.atoms.head, body(atomI)) ++ body.dropRight(body.length - atomI) ++ body.drop(atomI + 1), Some(originalK.cxns), Some(originalK.groupingIndexes)) + k = JoinIndexes(Seq(k.atoms.head, body(atomI)) ++ body.dropRight(body.length - atomI) ++ body.drop(atomI + 1), originalK.constraints, Some(originalK.cxns), Some((originalK.cons, 
originalK.pos2Term)), Some(originalK.groupingIndexes)) (outerT, innerT) else (innerT, outerT) @@ -188,6 +215,7 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager prefilter(k.constIndexes.filter((ind, _) => ind >= outerTuple.length && ind < (outerTuple.length + i.length)), outerTuple.length, i) && toJoin(k.varIndexes, outerTuple, i) ) .map(innerTuple => outerTuple.concat(innerTuple))) + .filter(r => consfilter(k.cons, r)) (edbResult, atomI + 1, k) ) result._1 @@ -224,11 +252,7 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager val reducers = gji.aggOpInfos.map(_._1).zip(tpes).map(aggOps(_)(_)) val okreducers = (a: CollectionsRow, b: CollectionsRow) => CollectionsRow(a.wrapped.zip(b.wrapped).zip(reducers).map((x, y) => y.apply(x._1.asInstanceOf[StorageConstant], x._2.asInstanceOf[StorageConstant]))) - val res = filteredBase.groupMapReduce(r => CollectionsRow(gji.groupingIndexes.map(r.apply)), okgetters, okreducers) - - val ctans = res.wrapped.map(_.wrapped.drop(gji.groupingIndexes.length)).flatten.toSeq - addConstantsToDomain(ctans) - res + filteredBase.groupMapReduce(r => CollectionsRow(gji.groupingIndexes.map(r.apply)), okgetters, okreducers) else getEmptyEDB() } @@ -263,10 +287,13 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager typ match case PredicateType.NEGATED => - val arity = k.atoms(i + 1).terms.length - val compl = getComplement(arity) - val res = diff(compl, q) - debug("found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}\n\tarity=$arity, compl=${printer.factToString(compl)}, Q=${printer.factToString(q)}, final res=${printer.factToString(res)}") + val nis = k.negationInfo(k.atoms(i + 1).hash) + val cols = nis.map(_.exists(_.isEmpty)) + + val compl = getGroundOf(nis) + val nq = zeroOut(q, cols) + val res = diff(compl, nq) + debug("found negated relation, rule=", () => 
s"${printer.ruleToString(k.atoms)}\n\tcompl=${printer.factToString(compl)}, Q=${printer.factToString(q)}, final res=${printer.factToString(res)}") res case PredicateType.GROUPING => val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom] @@ -292,10 +319,13 @@ class DefaultStorageManager(ns: NS = new NS()) extends CollectionsStorageManager val q = getKnownDerivedDB(r) typ match case PredicateType.NEGATED => - val arity = k.atoms(i + 1).terms.length - val compl = getComplement(arity) - val res = diff(compl, q) - debug(s"found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}\n\tarity=$arity, compl=${printer.factToString(compl)}, Q=${printer.factToString(q)}, final res=${printer.factToString(res)}") + val nis = k.negationInfo(k.atoms(i + 1).hash) + val cols = nis.map(_.exists(_.isEmpty)) + + val compl = getGroundOf(nis) + val nq = zeroOut(q, cols) + val res = diff(compl, nq) + debug(s"found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}\n\tcompl=${printer.factToString(compl)}, Q=${printer.factToString(q)}, final res=${printer.factToString(res)}") res case PredicateType.GROUPING => val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom] diff --git a/src/main/scala/datalog/storage/Printer.scala b/src/main/scala/datalog/storage/Printer.scala index 01db2557..415f29f2 100644 --- a/src/main/scala/datalog/storage/Printer.scala +++ b/src/main/scala/datalog/storage/Printer.scala @@ -104,7 +104,7 @@ class Printer[S <: StorageManager](val sm: S) { node match { case ProgramNode(allRules) => "PROGRAM\n" + allRules.map((rId, rules) => s" ${sm.ns(rId)} => ${printAST(rules)}").mkString("", "\n", "") case AllRulesNode(rules, rId, edb) => s"${if (edb) "{EDB}"+factToString(sm.getEDB(rId))+"{IDB}" else ""}${rules.map(printAST).mkString("[", "\n\t", " ]")}" - case RuleNode(head, body, atoms, k) => + case RuleNode(head, body, atoms, constraints, k) => s"\n\t${printAST(head)} :- ${body.map(printAST).mkString("(", ", ", ")")}" + s" => 
idx=${k.toStringWithNS(sm.ns)}\n" case n: AtomNode => n match { @@ -141,7 +141,7 @@ class Printer[S <: StorageManager](val sm: S) { k.toStringWithNS(sm.ns)}::${ children.map(o => printIR(o, ident+1)).mkString("(\n", ",\n", ")")}" case DiffOp(children:_*) => s"DIFF\n${printIR(children.head, ident+1)}\n-${printIR(children(1), ident+1)}" - case ComplementOp(arity) => s"COMPL|$arity|" + case NegationOp(child, cols) => s"NEG|${printIR(child, ident+1)}| [$cols]" case DebugNode(prefix, dbg) => s"DEBUG: $prefix" case DebugPeek(prefix, dbg, children:_*) => s"DEBUG PEEK: $prefix into: ${printIR(children.head)}" }) diff --git a/src/main/scala/datalog/storage/StorageManager.scala b/src/main/scala/datalog/storage/StorageManager.scala index e6d798a9..d983a0dd 100644 --- a/src/main/scala/datalog/storage/StorageManager.scala +++ b/src/main/scala/datalog/storage/StorageManager.scala @@ -24,8 +24,8 @@ trait StorageManager(val ns: NS) { def getAllEDBS(): mutable.Map[RelationId, Any] // if you ever just want to read the EDBs as a map, used for testing // Helpers for negation - def addConstantsToDomain(constants: Seq[StorageTerm]): Unit - def getComplement(arity: Int): CollectionsEDB + def getGroundOf(cols: Seq[Either[StorageConstant, Seq[(RelationId, Int)]]]): EDB + def zeroOut(input: EDB, cols: Seq[Boolean]): EDB def getKnownDerivedDB(rId: RelationId): EDB def getNewDerivedDB(rId: RelationId): EDB diff --git a/src/main/scala/datalog/storage/StorageTypes.scala b/src/main/scala/datalog/storage/StorageTypes.scala index 4231d6e2..b3dfb6aa 100644 --- a/src/main/scala/datalog/storage/StorageTypes.scala +++ b/src/main/scala/datalog/storage/StorageTypes.scala @@ -94,3 +94,66 @@ val aggOps: Map[StorageAggOp, Map[Char, (StorageConstant, StorageConstant) => St 's' -> ((a, b) => if a.asInstanceOf[String] > b.asInstanceOf[String] then a.asInstanceOf[String] else b.asInstanceOf[String]) ) ) + +enum StorageComparison: + case EQ, NEQ, LT, LTE, GT, GTE + +val comparisons: Map[StorageComparison, 
Map[Char, (StorageConstant, StorageConstant) => Boolean]] = Map( + StorageComparison.EQ -> Map( + 'i' -> ((x, y) => x.asInstanceOf[Int] == y.asInstanceOf[Int]), + 's' -> ((x, y) => x.asInstanceOf[String] == y.asInstanceOf[String]) + ), + StorageComparison.NEQ -> Map( + 'i' -> ((x, y) => x.asInstanceOf[Int] != y.asInstanceOf[Int]), + 's' -> ((x, y) => x.asInstanceOf[String] != y.asInstanceOf[String]) + ), + StorageComparison.LT -> Map( + 'i' -> ((x, y) => x.asInstanceOf[Int] < y.asInstanceOf[Int]), + 's' -> ((x, y) => x.asInstanceOf[String] < y.asInstanceOf[String]) + ), + StorageComparison.LTE -> Map( + 'i' -> ((x, y) => x.asInstanceOf[Int] <= y.asInstanceOf[Int]), + 's' -> ((x, y) => x.asInstanceOf[String] <= y.asInstanceOf[String]) + ), + StorageComparison.GT -> Map( + 'i' -> ((x, y) => x.asInstanceOf[Int] > y.asInstanceOf[Int]), + 's' -> ((x, y) => x.asInstanceOf[String] > y.asInstanceOf[String]) + ), + StorageComparison.GTE -> Map( + 'i' -> ((x, y) => x.asInstanceOf[Int] >= y.asInstanceOf[Int]), + 's' -> ((x, y) => x.asInstanceOf[String] >= y.asInstanceOf[String]) + ) +) + +enum StorageExpression: + case One(t: Either[StorageConstant, Int]) + case Add(l: StorageExpression, r: Either[StorageConstant, Int]) + case Sub(l: StorageExpression, r: Either[StorageConstant, Int]) + case Mul(l: StorageExpression, r: Either[StorageConstant, Int]) + case Div(l: StorageExpression, r: Either[StorageConstant, Int]) + case Mod(l: StorageExpression, r: Either[StorageConstant, Int]) + +def buildExpression(se: StorageExpression, tpe: Char): (Int => StorageTerm) => StorageConstant = + import StorageExpression.* + tpe match + case 'i' => + def aux(se: StorageExpression, get: Int => StorageTerm): Int = + se match + case One(t) => t.fold(x => x.asInstanceOf[Int], x => get(x).asInstanceOf[Int]) + case Add(l, r) => aux(l, get) + r.fold(x => x.asInstanceOf[Int], x => get(x).asInstanceOf[Int]) + case Sub(l, r) => aux(l, get) - r.fold(x => x.asInstanceOf[Int], x => 
get(x).asInstanceOf[Int]) + case Mul(l, r) => aux(l, get) * r.fold(x => x.asInstanceOf[Int], x => get(x).asInstanceOf[Int]) + case Div(l, r) => aux(l, get) / r.fold(x => x.asInstanceOf[Int], x => get(x).asInstanceOf[Int]) + case Mod(l, r) => aux(l, get) % r.fold(x => x.asInstanceOf[Int], x => get(x).asInstanceOf[Int]) + g => aux(se, g) + case 's' => + def aux(se: StorageExpression, get: Int => StorageTerm): String = + se match + case One(t) => t.fold(x => x.asInstanceOf[String], x => get(x).asInstanceOf[String]) + case Add(l, r) => aux(l, get) + r.fold(x => x.asInstanceOf[String], x => get(x).asInstanceOf[String]) + case Sub(l, r) => ??? + case Mul(l, r) => ??? + case Div(l, r) => ??? + case Mod(l, r) => ??? + g => aux(se, g) + diff --git a/src/main/scala/datalog/storage/VolcanoOperators.scala b/src/main/scala/datalog/storage/VolcanoOperators.scala index 0b7d883e..f66869a2 100644 --- a/src/main/scala/datalog/storage/VolcanoOperators.scala +++ b/src/main/scala/datalog/storage/VolcanoOperators.scala @@ -255,4 +255,22 @@ class VolcanoOperators[S <: StorageManager](val storageManager: S) { } def close(): Unit = input.close() } + + case class Negation(input: VolOperator, cols: Seq[Either[StorageConstant, Seq[(RelationId, Int)]]]) extends VolOperator { + private var outputRelation: CollectionsEDB = CollectionsEDB() + private var index = 0 + def open(): Unit = + val tmp = cols.map(_.exists(_.isEmpty)) + val compl = storageManager.getGroundOf(cols) + val nq = storageManager.zeroOut(input.toList(), tmp) + outputRelation = asCollectionsEDB(storageManager.diff(compl, nq)) + def next(): Option[CollectionsRow] = { + if (index >= outputRelation.length) + NilTuple + else + index += 1 + Option(outputRelation(index - 1)) + } + def close(): Unit = input.close() + } } \ No newline at end of file diff --git a/src/main/scala/datalog/storage/VolcanoStorageManager.scala b/src/main/scala/datalog/storage/VolcanoStorageManager.scala index 75614a2a..ef3fe8f7 100644 --- 
a/src/main/scala/datalog/storage/VolcanoStorageManager.scala +++ b/src/main/scala/datalog/storage/VolcanoStorageManager.scala @@ -41,10 +41,9 @@ class VolcanoStorageManager(ns: NS = NS()) extends CollectionsStorageManager(ns) val q = Scan(getKnownDerivedDB(r), r) typ match case PredicateType.NEGATED => - val arity = k.atoms(i + 1).terms.length - val compl = getComplement(arity) - val res = Diff(Seq(Scan(compl, r), q)) - debug(s"found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}\n\tarity=$arity") + val cols = k.negationInfo(k.atoms(i + 1).hash) + val res = Negation(q, cols) + debug(s"found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}") res case PredicateType.GROUPING => val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom] @@ -94,10 +93,9 @@ class VolcanoStorageManager(ns: NS = NS()) extends CollectionsStorageManager(ns) Scan(getKnownDerivedDB(r), r) typ match case PredicateType.NEGATED => - val arity = k.atoms(i + 1).terms.length - val compl = getComplement(arity) - val res = Diff(Seq(Scan(compl, r), q)) - debug(s"found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}\n\tarity=$arity") + val cols = k.negationInfo(k.atoms(i + 1).hash) + val res = Negation(q, cols) + debug(s"found negated relation, rule=", () => s"${printer.ruleToString(k.atoms)}") res case PredicateType.GROUPING => val ga = k.atoms(i + 1).asInstanceOf[GroupingAtom] diff --git a/src/test/scala/test/StratifiedNegationTests.scala b/src/test/scala/test/StratifiedNegationTests.scala index 50a58243..2552f61e 100644 --- a/src/test/scala/test/StratifiedNegationTests.scala +++ b/src/test/scala/test/StratifiedNegationTests.scala @@ -1,6 +1,6 @@ package test -import datalog.dsl.{Constant, Program, __, not} +import datalog.dsl.{Constant, Program, __, not, groupBy, AggOp} import datalog.execution.* import datalog.storage.DefaultStorageManager @@ -18,7 +18,7 @@ class StratifiedNegationTests extends munit.FunSuite { t(x, y) :- e(x, y) t(x, z) :- (t(x, y), 
e(y, z)) - interceptMessage[java.lang.Exception]("Variable with varId 0 appears only in negated rules") { + interceptMessage[java.lang.Exception]("Variable with varId 0 appears only in negated atoms (and possibly in aggregated positions of grouping atoms)") { tc(x, y) :- not(t(x, y)) // x and y are not limited. } } @@ -35,7 +35,7 @@ class StratifiedNegationTests extends munit.FunSuite { t(x, y) :- e(x, y) t(x, z) :- (t(x, y), e(y, z)) - interceptMessage[java.lang.Exception]("Variable with varId 2 appears only in negated rules") { + interceptMessage[java.lang.Exception]("Variable with varId 2 appears only in negated atoms (and possibly in aggregated positions of grouping atoms)") { tc(x, y) :- (e(x, y), !e(x, z), !e(x, z)) } } @@ -52,4 +52,18 @@ class StratifiedNegationTests extends munit.FunSuite { p.solve(e.id) } } + + test("stratified negation with aggregation") { + val p = Program(new StagedExecutionEngine(new DefaultStorageManager())) + val a = p.relation[Constant]("a") + val b = p.relation[Constant]("b") + val x, y = p.variable() + + a("A", 1) :- () + a("B", 2) :- () + + interceptMessage[java.lang.Exception]("Variable with varId 0 appears only in negated atoms (and possibly in aggregated positions of grouping atoms)") { + b(x) :- (!a(__, x), groupBy(a(__, y), Seq(), AggOp.SUM(y) -> x)) + } + } } diff --git a/src/test/scala/test/examples/array/array.scala b/src/test/scala/test/examples/array/array.scala new file mode 100644 index 00000000..3394cfea --- /dev/null +++ b/src/test/scala/test/examples/array/array.scala @@ -0,0 +1,53 @@ +package test.examples.array + +import buildinfo.BuildInfo +import datalog.dsl.* +import test.ExampleTestGenerator + +import java.nio.file.Paths + +class array_test extends ExampleTestGenerator("array") with array +trait array { + val factDirectory = s"${BuildInfo.baseDirectory}/src/test/scala/test/examples/array/facts" + val toSolve = "_" + def pretest(program: Program): Unit = { + val default = program.namedRelation("default") + val 
parameters = program.namedRelation("parameters") + val values = program.namedRelation("values") + + val dom = program.relation[Constant]("dom") + val indices = program.relation[Constant]("indices") + val element = program.relation[Constant]("element") + + val left = program.relation[Constant]("left") + val right = program.relation[Constant]("right") + val neighbourhood = program.relation[Constant]("neighbourhood") + + val i, j, k, l = program.variable() + + val tmp1 = program.relation[Constant]("tmp1") + val tmp2 = program.relation[Constant]("tmp2") + + (-1 to 100).foreach(x => + dom(x) :- () + ) + + indices(0) :- () + indices(i) :- (indices(j), parameters(k), dom(i), i |=| (j + 1), i |<| k) + + tmp1(i) :- values(i, __) + + element(i, j) :- (indices(i), !tmp1(i), default(j)) + element(i, j) :- (indices(i), values(i, j)) + + tmp2(i) :- element(i, __) + + left(i, j) :- (!tmp2(k), element(i, __), default(j), k |=| (i - 1), dom(k)) + left(i, j) :- (element(k, j), element(i, __), k |=| (i - 1)) + + right(i, j) :- (!tmp2(k), element(i, __), default(j), k |=| (i + 1), dom(k)) + right(i, j) :- (element(k, j), element(i, __), k |=| (i + 1)) + + neighbourhood(i, j, k, l) :- (left(i, j), element(i, k), right(i, l)) + } + } \ No newline at end of file diff --git a/src/test/scala/test/examples/array/expected/element.csv b/src/test/scala/test/examples/array/expected/element.csv new file mode 100644 index 00000000..78eb1273 --- /dev/null +++ b/src/test/scala/test/examples/array/expected/element.csv @@ -0,0 +1,11 @@ +Int Int +0 0 +4 0 +6 0 +7 0 +9 0 +1 13 +2 21 +3 34 +5 55 +8 89 diff --git a/src/test/scala/test/examples/array/expected/indices.csv b/src/test/scala/test/examples/array/expected/indices.csv new file mode 100644 index 00000000..d9e2f004 --- /dev/null +++ b/src/test/scala/test/examples/array/expected/indices.csv @@ -0,0 +1,11 @@ +Int +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 diff --git a/src/test/scala/test/examples/array/expected/neighbourhood.csv 
b/src/test/scala/test/examples/array/expected/neighbourhood.csv new file mode 100644 index 00000000..34814a38 --- /dev/null +++ b/src/test/scala/test/examples/array/expected/neighbourhood.csv @@ -0,0 +1,11 @@ +Int Int Int Int +0 0 0 13 +1 0 13 21 +5 0 55 0 +7 0 0 89 +8 0 89 0 +2 13 21 34 +3 21 34 0 +4 34 0 55 +6 55 0 0 +9 89 0 0 diff --git a/src/test/scala/test/examples/array/facts/default.facts b/src/test/scala/test/examples/array/facts/default.facts new file mode 100644 index 00000000..0dedbab9 --- /dev/null +++ b/src/test/scala/test/examples/array/facts/default.facts @@ -0,0 +1,2 @@ +Int +0 diff --git a/src/test/scala/test/examples/array/facts/parameters.facts b/src/test/scala/test/examples/array/facts/parameters.facts new file mode 100644 index 00000000..4a365a42 --- /dev/null +++ b/src/test/scala/test/examples/array/facts/parameters.facts @@ -0,0 +1,2 @@ +Int +10 diff --git a/src/test/scala/test/examples/array/facts/values.facts b/src/test/scala/test/examples/array/facts/values.facts new file mode 100644 index 00000000..1a40f077 --- /dev/null +++ b/src/test/scala/test/examples/array/facts/values.facts @@ -0,0 +1,6 @@ +Int Int +1 13 +2 21 +3 34 +5 55 +8 89 diff --git a/src/test/scala/test/examples/complement_explosion/complement_explosion.scala b/src/test/scala/test/examples/complement_explosion/complement_explosion.scala new file mode 100644 index 00000000..454addb9 --- /dev/null +++ b/src/test/scala/test/examples/complement_explosion/complement_explosion.scala @@ -0,0 +1,30 @@ +package test.examples.small + +import buildinfo.BuildInfo +import datalog.dsl.{Constant, Program} +import test.ExampleTestGenerator + +import java.nio.file.Paths + +class complement_explosion_test extends ExampleTestGenerator("complement_explosion") with complement_explosion +trait complement_explosion { + val factDirectory = s"${BuildInfo.baseDirectory}/src/test/scala/test/examples/complement_explosion/facts" + val toSolve = "derived" + def pretest(program: Program): Unit = { + val 
dom = program.relation[Constant]("dom") + + val base = program.relation[Constant]("base") + val derived = program.relation[Constant]("derived") + + val a, b, c, d, e, f = program.variable() + + (1 to 100).foreach(x => + dom(x) :- () + ) + + base(1, 2, 3, 4, 5, 6) :- () + base(11, 12, 13, 14, 15, 16) :- () + + derived(a, b, c, d, e, f) :- (base(a, b, c, d, e, f), !base(f, e, d, c, b, a)) + } +} diff --git a/src/test/scala/test/examples/complement_explosion/expected/derived.csv b/src/test/scala/test/examples/complement_explosion/expected/derived.csv new file mode 100644 index 00000000..ca520e42 --- /dev/null +++ b/src/test/scala/test/examples/complement_explosion/expected/derived.csv @@ -0,0 +1,3 @@ +Int Int Int Int Int Int +1 2 3 4 5 6 +11 12 13 14 15 16 diff --git a/src/test/scala/test/examples/dfa_min/dfa_min.scala b/src/test/scala/test/examples/dfa_min/dfa_min.scala new file mode 100644 index 00000000..13d309e8 --- /dev/null +++ b/src/test/scala/test/examples/dfa_min/dfa_min.scala @@ -0,0 +1,36 @@ +package test.examples.dfa_min + +import buildinfo.BuildInfo +import datalog.dsl.{Constant, Program, __, *} +import test.{ExampleTestGenerator, Tags} +class dfa_min_test extends ExampleTestGenerator("dfa_min") with dfa_min +trait dfa_min { + val factDirectory = s"${BuildInfo.baseDirectory}/src/test/scala/test/examples/dfa_min/facts" + val toSolve: String = "MinEquiv" + def pretest(program: Program): Unit = { + val Final = program.namedRelation("Final") + val Tr = program.namedRelation("Tr") + val Init = program.relation[Constant]("Init") + val Q = program.relation[Constant]("Q") + + val Dis = program.relation[Constant]("Dis") + val Equiv = program.relation[Constant]("Equiv") + val NotMinEquiv = program.relation[Constant]("NotMinEquiv") + val MinEquiv = program.relation[Constant]("MinEquiv") + + val q, r, a, s, t = program.variable() + + Init(0) :- () + + Q(q) :- Tr(q, __, __) + + Dis(q, r) :- (Q(q), Q(r), Final(q), !Final(r)) + Dis(q, r) :- (Tr(q, a, s), Tr(r, a, t), 
Dis(s, t)) + Dis(q, r) :- Dis(r, q) + + Equiv(q, r) :- (Q(q), Q(r), !Dis(q, r)) + + NotMinEquiv(q, r) :- (Equiv(q, r), Equiv(q, s), s |<| r) + MinEquiv(q, r) :- (Equiv(q, r), !NotMinEquiv(q, r)) + } +} diff --git a/src/test/scala/test/examples/dfa_min/expected/MinEquiv.csv b/src/test/scala/test/examples/dfa_min/expected/MinEquiv.csv new file mode 100644 index 00000000..c34d630f --- /dev/null +++ b/src/test/scala/test/examples/dfa_min/expected/MinEquiv.csv @@ -0,0 +1,6 @@ +Int Int +1 1 +2 2 +3 2 +4 4 +5 5 diff --git a/src/test/scala/test/examples/dfa_min/facts/Final.facts b/src/test/scala/test/examples/dfa_min/facts/Final.facts new file mode 100644 index 00000000..b62f487a --- /dev/null +++ b/src/test/scala/test/examples/dfa_min/facts/Final.facts @@ -0,0 +1,3 @@ +Int +1 +5 diff --git a/src/test/scala/test/examples/dfa_min/facts/Tr.facts b/src/test/scala/test/examples/dfa_min/facts/Tr.facts new file mode 100644 index 00000000..3508844c --- /dev/null +++ b/src/test/scala/test/examples/dfa_min/facts/Tr.facts @@ -0,0 +1,11 @@ +Int String Int +1 a 3 +1 b 4 +2 a 2 +2 b 1 +3 a 2 +3 b 1 +4 a 3 +4 b 5 +5 a 6 +5 b 1 diff --git a/src/test/scala/test/examples/inline_nats/expected/query.csv b/src/test/scala/test/examples/inline_nats/expected/query.csv new file mode 100644 index 00000000..4b133ce4 --- /dev/null +++ b/src/test/scala/test/examples/inline_nats/expected/query.csv @@ -0,0 +1,3 @@ +Int +0 +1 diff --git a/src/test/scala/test/examples/inline_nats/inline_nats.scala b/src/test/scala/test/examples/inline_nats/inline_nats.scala new file mode 100644 index 00000000..62dbc57f --- /dev/null +++ b/src/test/scala/test/examples/inline_nats/inline_nats.scala @@ -0,0 +1,24 @@ +package test.examples.inline_nats + +import buildinfo.BuildInfo +import datalog.dsl.{Constant, Program, __, *} +import test.{ExampleTestGenerator, Tags} +class inline_nats_test extends ExampleTestGenerator("inline_nats") with inline_nats +trait inline_nats { + val factDirectory = 
s"${BuildInfo.baseDirectory}/src/test/scala/test/examples/inline_nats/facts" + val toSolve: String = "query" + def pretest(program: Program): Unit = { + val nat = program.relation[Constant]("nat") + val query = program.relation[Constant]("query") + + val x, y = program.variable() + + (0 until 10000).foreach(x => + nat(x) :- () + ) + + query(x) :- (nat(x), x |<| 2) + + println({query.solve(); ()}) + } +} diff --git a/src/test/scala/test/examples/sequences/expected/palindrome.csv b/src/test/scala/test/examples/sequences/expected/palindrome.csv new file mode 100644 index 00000000..1ffae421 --- /dev/null +++ b/src/test/scala/test/examples/sequences/expected/palindrome.csv @@ -0,0 +1,26 @@ +Int +0 +1 +2 +3 +4 +8 +12 +13 +16 +19 +23 +26 +29 +33 +36 +39 +40 +52 +64 +68 +80 +92 +96 +108 +120 diff --git a/src/test/scala/test/examples/sequences/expected/read.csv b/src/test/scala/test/examples/sequences/expected/read.csv new file mode 100644 index 00000000..cee88a1e --- /dev/null +++ b/src/test/scala/test/examples/sequences/expected/read.csv @@ -0,0 +1,5 @@ +Int String +1 c +2 a +3 a +4 c diff --git a/src/test/scala/test/examples/sequences/sequences.scala b/src/test/scala/test/examples/sequences/sequences.scala new file mode 100644 index 00000000..b12def07 --- /dev/null +++ b/src/test/scala/test/examples/sequences/sequences.scala @@ -0,0 +1,129 @@ +package test.examples.sequences + +import buildinfo.BuildInfo +import datalog.dsl.{Constant, Program, __, *} +import test.{ExampleTestGenerator, Tags} +class sequences_test extends ExampleTestGenerator("sequences") with sequences +trait sequences { + val factDirectory = s"${BuildInfo.baseDirectory}/src/test/scala/test/examples/sequences/facts" + val toSolve: String = "_" + def pretest(program: Program): Unit = { + val num_letters = program.relation[Constant]("num_letters") + + val dom = program.relation[Constant]("dom") + + val a = program.relation[Constant]("a") + val n = program.relation[Constant]("n") + val s = 
program.relation[Constant]("s") + + val op_add = program.relation[Constant]("op_add") + val op_mul = program.relation[Constant]("op_mul") + val op_exp = program.relation[Constant]("op_exp") + val op_log = program.relation[Constant]("op_log") + val op_div = program.relation[Constant]("op_div") + val op_mod = program.relation[Constant]("op_mod") + + val idx = program.relation[Constant]("idx") + + val x, y, r, b, y2, z, px, py, pr, sr = program.variable() + + + num_letters(3) :- () + + a("a") :- () + a("b") :- () + a("c") :- () + + (0 to 120).foreach(o => + n(o) :- () + ) + + (0 to 120).foreach(o => + s(o, o+1) :- () + ) + + (0 to 121).foreach(o => + dom(o) :- () + ) + + op_add(x, 0, x) :- n(x) + op_add(x, y, r) :- (s(py, y), op_add(x, py, pr), s(pr, r)) + + op_mul(x, 0, 0) :- n(x) + op_mul(x, y, r) :- (s(py, y), op_mul(x, py, pr), op_add(pr, x, r)) + + + op_exp(x, 0, 1) :- n(x) + op_exp(x, y, r) :- (s(py, y), op_exp(x, py, pr), op_mul(pr, x, r)) + + op_log(x, b, r) :- (op_exp(b, r, y), s(r, sr), op_exp(b, sr, y2), n(x), y |<=| x, x |<| y2) + + op_div(x, y, r) :- (op_mul(r, y, sr), op_add(sr, z, x), 0 |<=| z, z |<| y) + + op_mod(x, y, r) :- (op_mul(y, __, z), op_add(z, r, x), 0 |<=| r, r |<| y) + + idx(0, "a") :- () + idx(1, "b") :- () + idx(2, "c") :- () + + val trie_letter = program.relation[Constant]("trie_letter") + val trie_level_end = program.relation[Constant]("trie_level_end") + val trie_level_start = program.relation[Constant]("trie_level_start") + val trie_level = program.relation[Constant]("trie_level") + val trie_parent = program.relation[Constant]("trie_parent") + val trie_root = program.relation[Constant]("trie_root") + val trie = program.relation[Constant]("trie") + + val pl, p, i, l, low, high, c, o = program.variable() + + trie_letter(z,b) :- (s(x,z), num_letters(sr), op_mod(x,sr,r), idx(r,b)) + + trie_level_end(0,0) :- () + trie_level_end(l,i) :- (num_letters(sr), s(pl,l), trie_level_end(pl,b), op_exp(sr,l,p), op_add(b,p,i)) + + trie_level_start(0,0) 
:- () + trie_level_start(l,i) :- (s(pl,l), trie_level_end(pl,b), op_add(b,1,i)) + + trie_level(0,0) :- () + trie_level(i,b) :- (n(i), s(z,b), trie_level_end(z,low), trie_level_end(b,high), low |<| i, i |<=| high) + + trie_parent(i,p) :- (num_letters(z), trie_level(i,l), s(pl,l), trie_level_start(l,b), op_add(b,x,i), op_div(x,z,o), trie_level_start(pl,c), op_add(c,o,p)) + + trie_root(0) :- () + + trie(x) :- trie_letter(x,__) + + val str = program.relation[Constant]("str") + val str_len = program.relation[Constant]("str_len") + val str_chain = program.relation[Constant]("str_chain") + val str_letter_at = program.relation[Constant]("str_letter_at") + + val id, sx, sy = program.variable() + + str(x) :- trie(x) + + str_len(id,l) :- trie_level(id, l) + + str_chain(id,id) :- trie(id) + str_chain(id,p) :- (str_chain(id,x), trie_parent(x,p)) + + str_letter_at(id,z,l) :- (str_chain(id,p), trie_level(p,z), trie_letter(p,l)) + + val palin_aux = program.relation[Constant]("palin_aux") + val palindrome = program.relation[Constant]("palindrome") + val debug_str = program.relation[Constant]("debug_str") + val read = program.relation[Constant]("read") + + palin_aux(b,x,x) :- (str(b), n(x), str_len(b,l), x |<=| l) + palin_aux(b,x,z) :- (str_letter_at(b, x, __), z |=| (x + 1), dom(z)) + palin_aux(b,x,sy) :- (str_letter_at(b,x,z), s(x,sx), palin_aux(b,sx,y), str_letter_at(b,y,z), s(y,sy)) + + palindrome(z) :- str_len(z,0) + palindrome(z) :- (palin_aux(z,1,sr), str_len(z,l), s(l,sr)) + + debug_str(96) :- () + + read(x,y) :- (debug_str(z), str_letter_at(z,x,y)) + + } +}