Skip to content

Commit

Permalink
test divmod enumerative sampler with BigInt
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Oct 14, 2023
1 parent 047d8af commit 0a923fc
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 37 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ kotlin {
val multikVersion = "0.2.2"
implementation("org.jetbrains.kotlinx:multik-core:$multikVersion")
implementation("org.jetbrains.kotlinx:multik-default:$multikVersion")
implementation("com.ionspin.kotlin:bignum:0.3.8")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ val CFG.vindex: Array<IntArray> by cache {

val CFG.bindex: Bindex<Σᐩ> by cache { Bindex(nonterminals) }
val CFG.normalForm: CFG by cache { normalize() }
val CFG.graph: LabeledGraph by cache { dependencyGraph() }
val CFG.depGraph: LabeledGraph by cache { dependencyGraph() }
val CFG.revDepGraph: LabeledGraph by cache { revDependencyGraph() }

val CFG.originalForm: CFG by cache { rewriteHistory[this]?.get(0) ?: this }
val CFG.nonparametricForm: CFG by cache { rewriteHistory[this]!![1] }
Expand Down Expand Up @@ -235,7 +236,11 @@ fun CFG.forestHash(s: Σᐩ) = parseForest(s).structureEncode()
fun CFG.nonterminalHash(s: Σᐩ) = s.tokenizeByWhitespace().map { preimage(it) }.hashCode()
fun CFG.preimage(vararg nts: Σᐩ): Set<Σᐩ> = bimap.R2LHS[nts.toList()] ?: emptySet()

fun CFG.dependencyGraph() = LabeledGraph { forEach { prod -> prod.second.forEach { rhs -> prod.LHS - rhs } } }
fun CFG.dependencyGraph() =
LabeledGraph { forEach { prod -> prod.second.forEach { rhs -> prod.LHS - rhs } } }

// Builds a LabeledGraph by visiting every production and, for each symbol on
// its right-hand side, evaluating `rhs - LHS` inside the graph builder. This
// is the same traversal as dependencyGraph() with the two operands swapped.
// NOTE(review): presumably `-` declares a directed edge in the LabeledGraph
// builder DSL, making this the edge-reversed dependency graph — confirm.
fun CFG.revDependencyGraph() =
LabeledGraph { forEach { prod -> prod.second.forEach { rhs -> rhs - prod.LHS } } }

fun CFG.jsonify() = "cfg = {\n" +
bimap.L2RHS.entries.joinToString("\n") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fun CJL.upwardClosure(terminals: Set<Σᐩ>): CJL =
fun CFG.upwardClosure(tokens: Set<Σᐩ>): CFG =
tokens.intersect(terminals).let {
if (it.isEmpty()) this
else (graph.reversed().transitiveClosure(tokens) - terminals)
else (depGraph.reversed().transitiveClosure(tokens) - terminals)
.let { closure -> filter { it.LHS in closure } }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ fun Production.allSubSeq(nullables: Set<Σᐩ>): Set<Production> =

/**
* Makes ε-productions optional. n.b. We do not use CNF, but almost-CNF!
* ε-productions are allowed, because want to be able to synthesize them
* ε-productions are allowed because we want to be able to synthesize them
* as special characters, then simply omit them during printing.
*
* - Determine nullable variables, i.e., those which contain ε on the RHS
Expand Down Expand Up @@ -204,12 +204,12 @@ fun LabeledGraph.transitiveClosure(from: Set<Σᐩ>) =

// All symbols that are reachable from START_SYMBOL
fun CFG.reachableSymbols(from: Σᐩ = START_SYMBOL): Set<Σᐩ> =
reachability.getOrPut(from) { graph.transitiveClosure(setOf(from)) }
reachability.getOrPut(from) { depGraph.transitiveClosure(setOf(from)) }

// All symbols that are either terminals or generate terminals
fun CFG.generatingSymbols(
from: Set<Σᐩ> = terminalUnitProductions.map { it.LHS }.toSet(),
revGraph: LabeledGraph = graph.reversed()
revGraph: LabeledGraph = revDepGraph
): Set<Σᐩ> = revGraph.transitiveClosure(from)

/* Drops variable unit productions, for example:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,66 @@ package ai.hypergraph.kaliningraph.parsing

import ai.hypergraph.kaliningraph.tensor.UTMatrix
import ai.hypergraph.kaliningraph.types.*
import com.ionspin.kotlin.bignum.integer.*
import kotlin.time.measureTimedValue

/**
 * Wraps the symbol [v] as a one-element branch list: a leaf [PTree] rooted at
 * [v] paired with a default-constructed [PTree] (whose root is "ε").
 */
fun PSingleton(v: String): List<Π2A<PTree>> = listOf(PTree(v) to PTree())

// Algebraic data type / polynomial functor for parse forests
class PTree(val root: String = "ε", val branches: List<Π2A<PTree>> = listOf()) {
val totalTrees: ULong by lazy {
if (branches.isEmpty()) 1uL
else branches.sumOf { (l, r) -> l.totalTrees * r.totalTrees }
// TODO: Use weighted choice mechanism
val shuffledBranches by lazy { branches.shuffled() }
val totalTrees: BigInteger by lazy {
if (branches.isEmpty()) BigInteger.ONE
else branches.map { (l, r) -> l.totalTrees * r.totalTrees }
.reduce { acc, it -> acc + it }
}

val depth: Int by lazy {
if (branches.isEmpty()) 0
else branches.maxOf { (l, r) -> maxOf(l.depth, r.depth) + 1 }
}

private val choice by lazy {
if (branches.isEmpty()) listOf(if ("ε" in root) "" else root)
else branches.shuffled().flatMap { (l, r) ->
// TODO: Use weighted choice mechanism
else shuffledBranches.flatMap { (l, r) ->
(l.choose() * r.choose()).map { (a, b) ->
if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b"
}
}.distinct().toList()
}.distinct()
}

private fun choose(): Sequence<String> = choice.asSequence()
fun choose(): Sequence<String> = choice.asSequence()

/**
 * Decodes the index [i] into one string of this forest with a divmod scheme:
 * the remainder of [i] modulo the number of branches selects an entry of
 * [shuffledBranches], and the quotient is threaded through the selected
 * left subtree and then the right subtree.
 *
 * @param i index to decode
 * @return the decoded string (a root containing "ε" contributes the empty
 *   string) paired with the part of the index this subtree did not consume
 */
private fun decode(i: BigInteger): Pair<String, BigInteger> {
  if (branches.isEmpty()) return (if ("ε" in root) "" else root) to i
  // Use a single BigInteger radix for both the quotient and the remainder.
  val radix = branches.size.toBigInteger()
  // The remainder is strictly smaller than branches.size, so it always fits
  // in an Int and can be narrowed directly, without going through a String.
  val branchIdx = i.mod(radix).intValue()
  val (lb, rb) = shuffledBranches[branchIdx]
  val (l, afterLeft) = lb.decode(i / radix)
  val (r, afterRight) = rb.decode(afterLeft)
  val concat = if (l.isEmpty()) r else if (r.isEmpty()) l else "$l $r"
  return concat to afterRight
}

// Returns the sequence of all strings derivable from the given PTree
// but needs a few seconds to warm up.
fun sampleWithoutReplacement(): Sequence<String> = choose()
fun sampleWithoutReplacement(): Sequence<String> = sequence {
println("Total trees in PTree: $totalTrees")
var i = BigInteger.ZERO
while (i < totalTrees) yield(decode(i++).first)
}

// Samples instantaneously from the parse forest, but may return duplicates
// and only returns a fraction of the number of distinct strings when compared
// to SWOR on medium-sized finite sets under the same wall-clock timeout. If
// the set is sufficiently large, distinctness will never be a problem.
fun sampleWithReplacement(): Sequence<String> = sequence { while(true) yield(sample()) }
fun sampleWithReplacement(): Sequence<String> = generateSequence { sample() }

/**
 * Draws a single string from this forest: a leaf yields its root (or the
 * empty string when the root contains "ε"); otherwise one branch is picked at
 * random and the samples of its two children are joined by a space, skipping
 * whichever side is empty.
 */
fun sample(): String {
  if (branches.isEmpty()) return if ("ε" in root) "" else root
  val (left, right) = branches.random()
  val prefix = left.sample()
  val suffix = right.sample()
  return when {
    prefix.isEmpty() -> suffix
    suffix.isEmpty() -> prefix
    else -> "$prefix $suffix"
  }
}

// TODO: Is there a sampler with no warmup that doesn't return duplicates?
// We want one that is as fast as SWR but with no dupes like SWOR.
}

fun CFG.startPTree(s: String) =
Expand Down Expand Up @@ -100,8 +118,15 @@ fun CFG.sliceSample(size: Int): Sequence<Σᐩ> =
sampleSeq(List(size) { "_" }.joinToString(" "))

// Lazily computes all strings that are syntactically compatible with the
// given template. Generally slow, but guaranteed to return all solutions.
fun CFG.solveSeq(s: String): Sequence<String> {
  val forest = startPTree(s) ?: return sequenceOf()
  return forest.choose().distinct()
}

// Should never return duplicates (results pass through distinct()) and is the
// second-fastest method. Eventually, this will become the default sampler.
fun CFG.enumSeq(s: String): Sequence<String> {
  val forest = startPTree(s) ?: return sequenceOf()
  return forest.sampleWithoutReplacement().distinct()
}

// Generally the fastest method, but the returned sequence may repeat strings.
fun CFG.sampleSeq(s: String): Sequence<String> {
  val forest = startPTree(s) ?: return sequenceOf()
  return forest.sampleWithReplacement()
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ operator fun <T> Set<T>.contains(s: Set<T>) = containsAll(s)
flatMap { s.map(it::to).toSet() }.toSet()

// IDK why the Kotlin stdlib provides these for Map but not Set
public inline fun <T> Set<T>.filter(predicate: (T) -> Boolean): Set<T> = filterTo(HashSet(), predicate)
public inline fun <T> Set<T>.filter(noinline predicate: (T) -> Boolean): Set<T> =
toMutableSet().apply { retainAll(predicate) }
// filterTo(HashSet(), predicate)
//public inline fun <T, Q> Set<T>.map(tx: (T) -> Q): Set<Q> = mapTo(HashSet(), tx)

interface VT<E, L: S<*>> : List<E> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class BarHillelTest {
assertFalse(testFail in levCFG.language)

val template = List(5) { "_" }.joinToString(" ")
val solutions = levCFG.solveSeq(template).toList().onEach { println(it) }
val solutions = levCFG.enumSeq(template).toList().onEach { println(it) }
println("Found ${solutions.size} solutions within Levenshtein distance 2 of \"$origStr\"")
}

Expand Down Expand Up @@ -185,4 +185,15 @@ class BarHillelTest {
" filtering should return the same solutions, but disjoint union was: " +
"${(lbhSet + efset) - (lbhSet intersect efset)}")
}

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.parsing.BarHillelTest.testPythonBarHillel"
*/
// Intersects the Python grammar with a Levenshtein automaton built around
// "1 + 2" (distance argument 1), then enumerates and prints every solution
// to a five-hole template.
@Test
fun testPythonBarHillel() {
  val pythonCFG = SetValiantTest.seq2parsePythonCFG.noEpsilonOrNonterminalStubs
  val levFSA = makeLevFSA("1 + 2", 1, pythonCFG.terminals)
  val intersection = pythonCFG.intersectLevFSA(levFSA)
  val template = List(5) { "_" }.joinToString(" ")
  intersection.enumSeq(template)
    .onEach { println(it) }
    .toList()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,18 +348,24 @@ class SetValiantTest {
*/
@Test
fun testSeqValiant() {
val detSols =
seq2parsePythonCFG.solveSeq("_ _ _ _ _").sortedBy { it.length }.toList()
var clock = TimeSource.Monotonic.markNow()
val detSols = seq2parsePythonCFG.noEpsilonOrNonterminalStubs
.enumSeq(List(20) {"_"}.joinToString(" "))
.take(10_000).sortedBy { it.length }.toList()

detSols.forEach { assertTrue("\"$it\" was invalid!") { it in seq2parsePythonCFG.language } }
println("Found ${detSols.size} determinstic solutions, all were valid!")

val clock = TimeSource.Monotonic.markNow()
var elapsed = clock.elapsedNow().inWholeMilliseconds
println("Found ${detSols.size} determinstic solutions in ${elapsed}ms or ~${detSols.size / (elapsed/1000.0)}/s, all were valid!")

clock = TimeSource.Monotonic.markNow()
val randSols = seq2parsePythonCFG.noEpsilonOrNonterminalStubs
.sliceSample(20).take(10_000).toList()
.sliceSample(20).take(10_000).toList().distinct()
.onEach { assertTrue("\"$it\" was invalid!") { it in seq2parsePythonCFG.language } }

// 10k in ~22094ms
println("Found ${randSols.size} random solutions in ${clock.elapsedNow().inWholeMilliseconds}ms, all were valid!")
elapsed = clock.elapsedNow().inWholeMilliseconds
println("Found ${randSols.size} random solutions in ${elapsed}ms or ~${randSols.size / (elapsed/1000.0)}/s, all were valid!")
}

companion object {
Expand Down Expand Up @@ -593,14 +599,15 @@ Yield_Arg -> From_Keyword Test | Testlist_Endcomma
val refStr = "NAME = ( NAME"
val refLst = refStr.tokenizeByWhitespace()
val template = List(refLst.size + 3) { "_" }.joinToString(" ")
println("Solving: $template")
measureTime {
// seq2parsePythonCFG.solve(template, levMetric(refStr))
seq2parsePythonCFG.solveSeq(template)
seq2parsePythonCFG.enumSeq(template)
.map { it to levenshtein(it, refStr) }
.filter { it.second < 4 }.distinct()
.sortedWith(compareBy({ it.second }, { it.first.length })).toList()
.also { it.take(1000).forEach { println("Δ=${it.second}: ${it.first}") } }
.filter { it.second < 4 }.distinct().take(100)
.sortedWith(compareBy({ it.second }, { it.first.length }))
.onEach { println("Δ=${it.second}: ${it.first}") }
// .onEach { println("Δ=${levenshtein(it, refStr)}: $it") }
.toList()
.also { println("Found ${it.size} solutions!") }
}.also { println("Finished in ${it.inWholeMilliseconds}ms.") }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ infix fun SATRubix.valEq(that: SATRubix): Formula =
// Conjoins, over every start symbol, the constraint that the entries of ltop
// and rtop at that symbol's bindex position are equal.
fun CFG.startFormula(ltop: SATVector, rtop: SATVector) =
  startSymbols
    .map { symbol -> bindex[symbol] }
    .map { idx -> ltop[idx] eq rtop[idx] }
    .reduce { conjunction, clause -> conjunction and clause }

fun CFG.downwardsReachabilitySeq() = graph
fun CFG.downwardsReachabilitySeq() = depGraph
.let { it.reachSequence(it.vertices.filter { it.label in startSymbols }.toSet()) }
.map { it.map { it.label }.toSet() }

fun CFG.upwardsReachabilitySeq() = graph
fun CFG.upwardsReachabilitySeq() = depGraph
.let { it.reachSequence(it.vertices.filter { it.label in terminals }.toSet(), it.A_AUG.transpose) }
.drop(1).map { it.map { it.label }.toSet() }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import ai.hypergraph.kaliningraph.types.*
import ai.hypergraph.kaliningraph.visualization.show
import org.junit.jupiter.api.Test
import kotlin.test.*
import kotlin.time.measureTime

/*
./gradlew jvmTest --tests "ai.hypergraph.kaliningraph.sat.SATValiantTest"
Expand Down Expand Up @@ -876,7 +875,7 @@ class SATValiantTest {
// }.also { println("Time: ${it}ms") }

println(arithCFG.originalForm.prettyPrint())
arithCFG.originalForm.graph.let {
arithCFG.originalForm.depGraph.let {
val start = it.vertices.first { it.label == "START" }
println(it.reachability(setOf(start), 2))
}
Expand Down

0 comments on commit 0a923fc

Please sign in to comment.