Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pattern matching support for list random access patterns #1143

Merged
merged 16 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions debug/kgdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def __init__(self, val, cat, sortName):
self.used_var_names = set()
self.long_int = gdb.lookup_type("long int")
self.bool_ptr = gdb.lookup_type("bool").pointer()
self.char_ptr = gdb.lookup_type("unsigned char").pointer()
self.unsigned_char = gdb.lookup_type("unsigned char")
self.string_ptr = gdb.lookup_type("string").pointer()
self.stringbuffer_ptr = gdb.lookup_type("stringbuffer").pointer()
Expand Down Expand Up @@ -483,6 +484,20 @@ def appendInt(self, val, sort):
self.appendLimbs(size, val.dereference()['_mp_d'])
self.result += "\")"

def appendMInt(self, val, width, sort):
self.result += "\\dv{" + sort + "}(\""
self.appendLE(val.cast(self.char_ptr), width)
self.result += "\")"

def appendLE(self, ptr, size):
accum = 0
for i in range(size-1,-1,-1):
accum <<= 8
byte = int(ptr[i])
accum |= byte
self.result += str(accum)
self.result += "p" + str(size * 8)

def appendList(self, val, sort):
length = val.dereference()['impl_']['size']
if length == 0:
Expand Down Expand Up @@ -632,6 +647,14 @@ def append(self, subject, isVar, sort):
self.result += "\\dv{" + sort + "}(\"" + string + "\")"
elif cat == @STRINGBUFFER_LAYOUT@:
self.appendStringBuffer(arg.cast(self.stringbuffer_ptr_ptr).dereference(), sort)
elif cat == @MINT_LAYOUT@ + 32:
self.appendMInt(arg, 4, sort)
elif cat == @MINT_LAYOUT@ + 64:
self.appendMInt(arg, 8, sort)
elif cat == @MINT_LAYOUT@ + 160:
self.appendMInt(arg, 20, sort)
elif cat == @MINT_LAYOUT@ + 256:
self.appendMInt(arg, 32, sort)
else:
raise ValueError()
if i != nargs - 1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.util.Optional;

public final class MatchingException extends Throwable {
public final class MatchingException extends RuntimeException {
public enum Type {
USELESS_RULE,
NON_EXHAUSTIVE_MATCH,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,62 +35,89 @@ case class NonEmpty() extends Constructor {
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(this)
}

case class HasKey(isSet: Boolean, element: SymbolOrAlias, key: Option[Pattern[Option[Occurrence]]])
extends Constructor {
case class HasKey(
cat: SortCategory,
element: SymbolOrAlias,
key: Option[Pattern[Option[Occurrence]]]
) extends Constructor {
def name = "1"
def isBest(pat: Pattern[Option[Occurrence]]): Boolean = key.isDefined && pat == key.get
def expand(f: Fringe): Option[immutable.Seq[Fringe]] = {
val sorts = f.symlib.signatures(element)._1
key match {
case None =>
if (isSet) {
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
cat match {
case SetS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
)
)
case MapS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, sorts(1), ChoiceValue(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
)
)
)
} else {
Some(
immutable.Seq(
Fringe(f.symlib, sorts.head, Choice(f.occurrence), isExact = false),
Fringe(f.symlib, sorts(1), ChoiceValue(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
case ListS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts(1), Choice(f.occurrence), isExact = false),
Fringe(f.symlib, sorts(2), ChoiceValue(f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, ChoiceRem(f.occurrence), isExact = false)
)
)
)
case _ => ???
}
case Some(k) =>
if (isSet) {
Some(immutable.Seq(Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false), f))
} else {
Some(
immutable.Seq(
Fringe(f.symlib, sorts(1), Value(k, f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false),
f
cat match {
case SetS() =>
Some(immutable.Seq(Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false), f))
case MapS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts(1), Value(k, f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, Rem(k, f.occurrence), isExact = false),
f
)
)
case ListS() =>
Some(
immutable.Seq(
Fringe(f.symlib, sorts(2), Value(k, f.occurrence), isExact = false),
Fringe(f.symlib, f.sort, f.occurrence, isExact = false),
f
)
)
)
case _ => ???
}
}
}
def contract(f: Fringe, children: immutable.Seq[Pattern[String]]): Pattern[String] = {
val child = children.last
var key: Pattern[String] = null
var value: Pattern[String] = null
assert((isSet && children.size == 2) || (!isSet && children.size == 3))
assert((cat == SetS() && children.size == 2) || (cat != SetS() && children.size == 3))
if (this.key.isEmpty) {
if (isSet) {
key = children.head
} else {
key = children.head
value = children(1)
cat match {
case SetS() =>
key = children.head
case MapS() =>
key = children.head
value = children(1)
case _ => ???
}
} else {
if (isSet) {
key = this.key.get.decanonicalize
} else {
key = this.key.get.decanonicalize
value = children.head
cat match {
case SetS() =>
key = this.key.get.decanonicalize
case ListS() | MapS() =>
key = this.key.get.decanonicalize
value = children.head
case _ => ???
}
}
def element(k: Pattern[String], v: Pattern[String]): Pattern[String] =
Expand All @@ -99,35 +126,53 @@ case class HasKey(isSet: Boolean, element: SymbolOrAlias, key: Option[Pattern[Op
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get, immutable.Seq(k))
def concat(m1: Pattern[String], m2: Pattern[String]): Pattern[String] =
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "concat").get, immutable.Seq(m1, m2))
def update(m1: Pattern[String], m2: Pattern[String], m3: Pattern[String]): Pattern[String] =
SymbolP(
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
immutable.Seq(m1, m2, m3)
)
child match {
case MapP(keys, values, frame, ctr, orig) =>
MapP(key +: keys, value +: values, frame, ctr, orig)
case ListGetP(keys, values, frame, ctr, orig) =>
ListGetP(key +: keys, value +: values, frame, ctr, orig)
case SetP(elems, frame, ctr, orig) =>
SetP(key +: elems, frame, ctr, orig)
case WildcardP() | VariableP(_, _) =>
if (isSet) {
SetP(
immutable.Seq(key),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(key), child)
)
} else {
MapP(
immutable.Seq(key),
immutable.Seq(value),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(key, value), child)
)
cat match {
case SetS() =>
SetP(
immutable.Seq(key),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(key), child)
)
case MapS() =>
MapP(
immutable.Seq(key),
immutable.Seq(value),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(key, value), child)
)
case ListS() =>
ListGetP(
immutable.Seq(key),
immutable.Seq(value),
child,
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
update(child, key, value)
)
case _ => ???
}
case _ => ???
}
}
override lazy val hashCode: Int = scala.runtime.ScalaRunTime._hashCode(this)
}

case class HasNoKey(isSet: Boolean, key: Option[Pattern[Option[Occurrence]]]) extends Constructor {
case class HasNoKey(cat: SortCategory, key: Option[Pattern[Option[Occurrence]]])
extends Constructor {
def name = "0"
def isBest(pat: Pattern[Option[Occurrence]]): Boolean = key.isDefined && pat == key.get
def expand(f: Fringe): Option[immutable.Seq[Fringe]] = Some(immutable.Seq(f))
Expand All @@ -141,6 +186,11 @@ case class HasNoKey(isSet: Boolean, key: Option[Pattern[Option[Occurrence]]]) ex
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "unit").get, immutable.Seq())
def concat(m1: Pattern[String], m2: Pattern[String]): Pattern[String] =
SymbolP(Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "concat").get, immutable.Seq(m1, m2))
def update(m1: Pattern[String], m2: Pattern[String], m3: Pattern[String]): Pattern[String] =
SymbolP(
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
immutable.Seq(m1, m2, m3)
)
def wildcard = WildcardP[String]()
child match {
case MapP(keys, values, frame, ctr, orig) =>
Expand All @@ -154,21 +204,31 @@ case class HasNoKey(isSet: Boolean, key: Option[Pattern[Option[Occurrence]]]) ex
case SetP(elems, frame, ctr, orig) =>
SetP(wildcard +: elems, frame, ctr, concat(setElement(wildcard), orig))
case WildcardP() | VariableP(_, _) =>
if (isSet) {
SetP(
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(wildcard), child)
)
} else {
MapP(
immutable.Seq(wildcard),
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(wildcard, wildcard), child)
)
cat match {
case SetS() =>
SetP(
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(setElement(wildcard), child)
)
case MapS() =>
MapP(
immutable.Seq(wildcard),
immutable.Seq(wildcard),
Some(child),
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "element").get,
concat(element(wildcard, wildcard), child)
)
case ListS() =>
ListGetP(
immutable.Seq(wildcard),
immutable.Seq(wildcard),
child,
Parser.getSymbolAtt(f.symlib.sortAtt(f.sort), "update").get,
update(child, wildcard, wildcard)
)
case _ => ???
}
case _ => ???
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.runtimeverification.k.kore._
import org.kframework.backend.llvm.matching.dt.DecisionTree
import org.kframework.backend.llvm.matching.pattern.{ Pattern => P }
import org.kframework.backend.llvm.matching.pattern.AsP
import org.kframework.backend.llvm.matching.pattern.ListGetP
import org.kframework.backend.llvm.matching.pattern.ListP
import org.kframework.backend.llvm.matching.pattern.LiteralP
import org.kframework.backend.llvm.matching.pattern.MapP
Expand Down Expand Up @@ -51,6 +52,25 @@ object Generator {
case _ => ???
}

private def listGetPattern(
sym: SymbolOrAlias,
ps: immutable.Seq[P[String]],
c: SymbolOrAlias
): P[String] =
ps match {
case immutable.Seq(p @ (WildcardP() | VariableP(_, _)), k, v) =>
ListGetP(immutable.Seq(k), immutable.Seq(v), p, c, SymbolP(sym, immutable.Seq(p, k, v)))
case immutable.Seq(ListGetP(ks, vs, frame, _, o), k, v) =>
ListGetP(
ks ++ immutable.Seq(k),
vs ++ immutable.Seq(v),
frame,
c,
SymbolP(sym, immutable.Seq(o, k, v))
)
case _ => ???
}

private def mapPattern(
sym: SymbolOrAlias,
cons: CollectionCons,
Expand Down Expand Up @@ -116,6 +136,8 @@ object Generator {
): List[P[String]] = {
def getElementSym(sort: Sort): SymbolOrAlias =
Parser.getSymbolAtt(symlib.sortAtt(sort), "element").get
def getUpdateSym(sort: Sort): SymbolOrAlias =
Parser.getSymbolAtt(symlib.sortAtt(sort), "update").get
def genPattern(pat: Pattern): P[String] =
pat match {
case Application(sym, ps) =>
Expand All @@ -128,6 +150,8 @@ object Generator {
case Some("LIST.unit") => listPattern(sym, Unit(), immutable.Seq(), getElementSym(sort))
case Some("LIST.element") =>
listPattern(sym, Element(), ps.map(genPattern), getElementSym(sort))
case Some("LIST.update") =>
listGetPattern(sym, ps.map(genPattern), getUpdateSym(sort))
case Some("MAP.concat") =>
mapPattern(sym, Concat(), ps.map(genPattern), getElementSym(sort))
case Some("MAP.unit") => mapPattern(sym, Unit(), immutable.Seq(), getElementSym(sort))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ sealed trait Heuristic {
key: Option[Pattern[Option[Occurrence]]],
isEmpty: Boolean
): Double = ???
def scoreListGet[T](
p: ListGetP[T],
f: Fringe,
c: Clause,
key: Option[Pattern[Option[Occurrence]]],
isEmpty: Boolean
): Double = ???
def scoreOr[T](
p: OrP[T],
f: Fringe,
Expand Down Expand Up @@ -148,6 +155,20 @@ object DefaultHeuristic extends Heuristic {
} else {
1.0
}
override def scoreListGet[T](
p: ListGetP[T],
f: Fringe,
c: Clause,
key: Option[Pattern[Option[Occurrence]]],
isEmpty: Boolean
): Double =
if (p.keys.isEmpty) {
p.frame.score(this, f, c, key, isEmpty)
} else if (key.isDefined) {
if (p.canonicalize(c).keys.contains(key.get)) 1.0 else 0.0
} else {
1.0
}

override def scoreOr[T](
p: OrP[T],
Expand Down
Loading
Loading