Skip to content

Commit

Permalink
Add pattern matching support for list random access patterns (#1143)
Browse files Browse the repository at this point in the history
This is not rigorously tested, nor is it likely to be in the near future
because the code it was written to support ended up not being worth
merging. However, it is broadly useful code that may find value in the
future, so I would like to get it merged.

It also includes a few fixes to python scripts in the repo that had gone
slightly stale.
  • Loading branch information
Dwight Guth authored Sep 11, 2024
1 parent 9d21e4d commit dafe4ec
Show file tree
Hide file tree
Showing 21 changed files with 2,666 additions and 113 deletions.
23 changes: 23 additions & 0 deletions debug/kgdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def __init__(self, val, cat, sortName):
self.used_var_names = set()
self.long_int = gdb.lookup_type("long int")
self.bool_ptr = gdb.lookup_type("bool").pointer()
self.char_ptr = gdb.lookup_type("unsigned char").pointer()
self.unsigned_char = gdb.lookup_type("unsigned char")
self.string_ptr = gdb.lookup_type("string").pointer()
self.stringbuffer_ptr = gdb.lookup_type("stringbuffer").pointer()
Expand Down Expand Up @@ -483,6 +484,20 @@ def appendInt(self, val, sort):
self.appendLimbs(size, val.dereference()['_mp_d'])
self.result += "\")"

def appendMInt(self, val, width, sort):
self.result += "\\dv{" + sort + "}(\""
self.appendLE(val.cast(self.char_ptr), width)
self.result += "\")"

def appendLE(self, ptr, size):
accum = 0
for i in range(size-1,-1,-1):
accum <<= 8
byte = int(ptr[i])
accum |= byte
self.result += str(accum)
self.result += "p" + str(size * 8)

def appendList(self, val, sort):
length = val.dereference()['impl_']['size']
if length == 0:
Expand Down Expand Up @@ -632,6 +647,14 @@ def append(self, subject, isVar, sort):
self.result += "\\dv{" + sort + "}(\"" + string + "\")"
elif cat == @STRINGBUFFER_LAYOUT@:
self.appendStringBuffer(arg.cast(self.stringbuffer_ptr_ptr).dereference(), sort)
elif cat == @MINT_LAYOUT@ + 32:
self.appendMInt(arg, 4, sort)
elif cat == @MINT_LAYOUT@ + 64:
self.appendMInt(arg, 8, sort)
elif cat == @MINT_LAYOUT@ + 160:
self.appendMInt(arg, 20, sort)
elif cat == @MINT_LAYOUT@ + 256:
self.appendMInt(arg, 32, sort)
else:
raise ValueError()
if i != nargs - 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.util.Optional;

public final class MatchingException extends Throwable {
public final class MatchingException extends RuntimeException {
public enum Type {
USELESS_RULE,
NON_EXHAUSTIVE_MATCH,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,62 +35,89 @@ case class NonEmpty() extends Constructor {
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(this)
}

case class HasKey(isSet: Boolean, element: SymbolOrAlias, key: Option[Pattern[Option[Occurrence]]])
extends Constructor {
case class HasKey(
cat: SortCategory,
element: SymbolOrAlias,
key: Option[Pattern[Option[Occurrence]]]
) extends Constructor {
def name = "1"
def isBest(pat: Pattern[Option[Occurrence]]): Boolean = key.isDefined && pat == key.get
def expand(f: Fringe): Option[immutable.Seq[Fringe]] = {
val sorts = f.symlib.signatures(element)._1
key match {
case None =>
if (isSet) {
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
cat match {
case SetS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
)
)
case MapS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, sorts(1), ChoiceValue(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
)
)
)
} else {
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, sorts(1), ChoiceValue(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
case ListS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts(1), Choice(f.occurrence), isExact = false),
Fringe(f.symlib, sorts(2), ChoiceValue(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
)
)
)
case _ => ???
}
case Some(k) =>
if (isSet) {
Some(immutable.Seq(Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false), f))
} else {
Some(
immutable.Seq(
Fringe(f.symlib, sorts(1), Value(k, f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false),
f
cat match {
case SetS() =>
Some(immutable.Seq(Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false), f))
case MapS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts(1), Value(k, f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false),
f
)
)
case ListS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts(2), Value(k, f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, f.occurrence, isExact = false),
f
)
)
)
case _ => ???
}
}
}
def contract(f: Fringe, children: immutable.Seq[Pattern[String]]): Pattern[String] = {
val child = children.last
var key: Pattern[String] = null
var value: Pattern[String] = null
assert((isSet && children.size == 2) || (!isSet && children.size == 3))
assert((cat == SetS() && children.size == 2) || (cat != SetS() && children.size == 3))
if (this.key.isEmpty) {
if (isSet) {
key = children.head
} else {
key = children.head
value = children(1)
cat match {
case SetS() =>
key = children.head
case MapS() =>
key = children.head
value = children(1)
case _ => ???
}
} else {
if (isSet) {
key = this.key.get.decanonicalize
} else {
key = this.key.get.decanonicalize
value = children.head
cat match {
case SetS() =>
key = this.key.get.decanonicalize
case ListS() | MapS() =>
key = this.key.get.decanonicalize
value = children.head
case _ => ???
}
}
def element(k: Pattern[String], v: Pattern[String]): Pattern[String] =
Expand All @@ -99,35 +126,53 @@ case class HasKey(isSet: Boolean, element: SymbolOrAlias, key: Option[Pattern[Op
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get, immutable.Seq(k))
def concat(m1: Pattern[String], m2: Pattern[String]): Pattern[String] =
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "concat").get, immutable.Seq(m1, m2))
def update(m1: Pattern[String], m2: Pattern[String], m3: Pattern[String]): Pattern[String] =
SymbolP(
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
immutable.Seq(m1, m2, m3)
)
child match {
case MapP(keys, values, frame, ctr, orig) =>
MapP(key +: keys, value +: values, frame, ctr, orig)
case ListGetP(keys, values, frame, ctr, orig) =>
ListGetP(key +: keys, value +: values, frame, ctr, orig)
case SetP(elems, frame, ctr, orig) =>
SetP(key +: elems, frame, ctr, orig)
case WildcardP() | VariableP(_, _) =>
if (isSet) {
SetP(
immutable.Seq(key),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(key), child)
)
} else {
MapP(
immutable.Seq(key),
immutable.Seq(value),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(key, value), child)
)
cat match {
case SetS() =>
SetP(
immutable.Seq(key),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(key), child)
)
case MapS() =>
MapP(
immutable.Seq(key),
immutable.Seq(value),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(key, value), child)
)
case ListS() =>
ListGetP(
immutable.Seq(key),
immutable.Seq(value),
child,
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
update(child, key, value)
)
case _ => ???
}
case _ => ???
}
}
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(this)
}

case class HasNoKey(isSet: Boolean, key: Option[Pattern[Option[Occurrence]]]) extends Constructor {
case class HasNoKey(cat: SortCategory, key: Option[Pattern[Option[Occurrence]]])
extends Constructor {
def name = "0"
def isBest(pat: Pattern[Option[Occurrence]]): Boolean = key.isDefined && pat == key.get
def expand(f: Fringe): Option[immutable.Seq[Fringe]] = Some(immutable.Seq(f))
Expand All @@ -141,6 +186,11 @@ case class HasNoKey(isSet: Boolean, key: Option[Pattern[Option[Occurrence]]]) ex
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "unit").get, immutable.Seq())
def concat(m1: Pattern[String], m2: Pattern[String]): Pattern[String] =
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "concat").get, immutable.Seq(m1, m2))
def update(m1: Pattern[String], m2: Pattern[String], m3: Pattern[String]): Pattern[String] =
SymbolP(
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
immutable.Seq(m1, m2, m3)
)
def wildcard = WildcardP[String]()
child match {
case MapP(keys, values, frame, ctr, orig) =>
Expand All @@ -154,21 +204,31 @@ case class HasNoKey(isSet: Boolean, key: Option[Pattern[Option[Occurrence]]]) ex
case SetP(elems, frame, ctr, orig) =>
SetP(wildcard +: elems, frame, ctr, concat(setElement(wildcard), orig))
case WildcardP() | VariableP(_, _) =>
if (isSet) {
SetP(
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(wildcard), child)
)
} else {
MapP(
immutable.Seq(wildcard),
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(wildcard, wildcard), child)
)
cat match {
case SetS() =>
SetP(
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(wildcard), child)
)
case MapS() =>
MapP(
immutable.Seq(wildcard),
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(wildcard, wildcard), child)
)
case ListS() =>
ListGetP(
immutable.Seq(wildcard),
immutable.Seq(wildcard),
child,
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
update(child, wildcard, wildcard)
)
case _ => ???
}
case _ => ???
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.runtimeverification.k.kore._
import org.kframework.backend.llvm.matching.dt.DecisionTree
import org.kframework.backend.llvm.matching.pattern.{ Pattern => P }
import org.kframework.backend.llvm.matching.pattern.AsP
import org.kframework.backend.llvm.matching.pattern.ListGetP
import org.kframework.backend.llvm.matching.pattern.ListP
import org.kframework.backend.llvm.matching.pattern.LiteralP
import org.kframework.backend.llvm.matching.pattern.MapP
Expand Down Expand Up @@ -51,6 +52,25 @@ object Generator {
case _ => ???
}

private def listGetPattern(
sym: SymbolOrAlias,
ps: immutable.Seq[P[String]],
c: SymbolOrAlias
): P[String] =
ps match {
case immutable.Seq(p @ (WildcardP() | VariableP(_, _)), k, v) =>
ListGetP(immutable.Seq(k), immutable.Seq(v), p, c, SymbolP(sym, immutable.Seq(p, k, v)))
case immutable.Seq(ListGetP(ks, vs, frame, _, o), k, v) =>
ListGetP(
ks ++ immutable.Seq(k),
vs ++ immutable.Seq(v),
frame,
c,
SymbolP(sym, immutable.Seq(o, k, v))
)
case _ => ???
}

private def mapPattern(
sym: SymbolOrAlias,
cons: CollectionCons,
Expand Down Expand Up @@ -116,6 +136,8 @@ object Generator {
): List[P[String]] = {
def getElementSym(sort: Sort): SymbolOrAlias =
Parser.getSymbolAtt(symlib.sortAtt(sort), "element").get
def getUpdateSym(sort: Sort): SymbolOrAlias =
Parser.getSymbolAtt(symlib.sortAtt(sort), "update").get
def genPattern(pat: Pattern): P[String] =
pat match {
case Application(sym, ps) =>
Expand All @@ -128,6 +150,8 @@ object Generator {
case Some("LIST.unit") => listPattern(sym, Unit(), immutable.Seq(), getElementSym(sort))
case Some("LIST.element") =>
listPattern(sym, Element(), ps.map(genPattern), getElementSym(sort))
case Some("LIST.update") =>
listGetPattern(sym, ps.map(genPattern), getUpdateSym(sort))
case Some("MAP.concat") =>
mapPattern(sym, Concat(), ps.map(genPattern), getElementSym(sort))
case Some("MAP.unit") => mapPattern(sym, Unit(), immutable.Seq(), getElementSym(sort))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ sealed trait Heuristic {
key: Option[Pattern[Option[Occurrence]]],
isEmpty: Boolean
): Double = ???
def scoreListGet[T](
p: ListGetP[T],
f: Fringe,
c: Clause,
key: Option[Pattern[Option[Occurrence]]],
isEmpty: Boolean
): Double = ???
def scoreOr[T](
p: OrP[T],
f: Fringe,
Expand Down Expand Up @@ -148,6 +155,20 @@ object DefaultHeuristic extends Heuristic {
} else {
1.0
}
override def scoreListGet[T](
p: ListGetP[T],
f: Fringe,
c: Clause,
key: Option[Pattern[Option[Occurrence]]],
isEmpty: Boolean
): Double =
if (p.keys.isEmpty) {
p.frame.score(this, f, c, key, isEmpty)
} else if (key.isDefined) {
if (p.canonicalize(c).keys.contains(key.get)) 1.0 else 0.0
} else {
1.0
}

override def scoreOr[T](
p: OrP[T],
Expand Down
Loading

0 comments on commit dafe4ec

Please sign in to comment.