From 4f0e94acc2fc727985fb31b24d0408214cff63ab Mon Sep 17 00:00:00 2001 From: breandan Date: Mon, 24 Jun 2024 23:41:49 -0400 Subject: [PATCH] integrate GWA decoder into LBH repair --- .../ai/hypergraph/kaliningraph/parsing/CFG.kt | 12 +- .../kaliningraph/parsing/SeqValiant.kt | 4 +- .../hypergraph/kaliningraph/automata/JFSA.kt | 125 +++++++++++------- .../hypergraph/markovian/mcmc/MarkovChain.kt | 11 +- .../kaliningraph/automata/WFSATest.kt | 4 +- 5 files changed, 99 insertions(+), 57 deletions(-) diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt index 70fdaba5..caa2b02f 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt @@ -6,6 +6,7 @@ import ai.hypergraph.kaliningraph.sampling.choose import ai.hypergraph.kaliningraph.tokenizeByWhitespace import ai.hypergraph.kaliningraph.types.* import kotlin.jvm.JvmName +import kotlin.random.Random import kotlin.time.* import kotlin.time.Duration.Companion.seconds @@ -388,4 +389,13 @@ fun CFG.jsonify() = "cfg = {\n" + ("\"${it.key}\": [${it.value.joinToString(", ") { it.joinToString(", ", "(", ")") { "\"$it\"" } }}],") - } + "\n}" \ No newline at end of file + } + "\n}" + +class TermDict( + val terms: Set<Σᐩ>, + val dict: Map = terms.associateBy { Random(it.hashCode()).nextInt().toChar() }, + val revDict: Map<Σᐩ, Char> = dict.entries.associate { (k, v) -> v to k } +) : Map by dict { + fun encode(str: String) = str.tokenizeByWhitespace().map { revDict[it]!! }.joinToString("") + fun encode(str: List) = str.map { revDict[it]!! }.joinToString("") +} \ No newline at end of file diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt index 8887a6e2..ef414e83 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt @@ -24,11 +24,13 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() (1 + branches.sumOf { (l, r) -> l.branchRatio.second + r.branchRatio.second }) } - val allTerminals: Set by lazy { + val allTerminals: Set<Σᐩ> by lazy { if (branches.isEmpty()) setOf(root) else branches.map { (l, r) -> l.allTerminals + r.allTerminals }.flatten().toSet() } + val termDict by lazy { TermDict(allTerminals) } + // Σ^n/|T(n)|, if < 1 then we know the grammar is surely ambiguous val inverseDensity by lazy { measureTimedValue { allTerminals.size.toBigInteger().pow(depth) / totalTrees } diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt index 34acc946..1832c6ca 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt @@ -1,11 +1,12 @@ package ai.hypergraph.kaliningraph.automata +import NUM_CORES import ai.hypergraph.kaliningraph.graphs.LabeledGraph import ai.hypergraph.kaliningraph.parsing.* import ai.hypergraph.markovian.mcmc.MarkovChain import dk.brics.automaton.Automaton.* import dk.brics.automaton.Transition -import java.util.* +import java.util.concurrent.* import kotlin.random.Random import kotlin.time.* @@ -49,31 +50,35 @@ fun JAutomaton.toDot(processed: MutableSet = mutableSetOf() * previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1}) */ -data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) { +data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double) { val isComplete: Boolean = lastState.isAccept - override fun toString() = toks.reversed().filterNotNull().joinToString(" ") + val tokens by lazy { traj.reversed().filterNotNull() } + override fun toString() = tokens.joinToString(" ") } -fun PTree.decodeDFA(mc: MarkovChain<Σᐩ>, topK: Int = 10_000_000): List<Σᐩ> = propagator( - both = { a, b -> if (a == null) b else if (b == null) a else a.concatenate(b) }, - either = { a, b -> if (a == null) b else if (b == null) a else a.union(b) }, - unit = { a -> - if ("ε" in a.root) null - else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar()) -// EditableAutomaton(RealSemiring()).apply { -// val s1 = addState(1.0, 0.0) -// val s2 = addState(0.0, 1.0) -// addTransition(s1, s2, a.root, 1.0) -// } - } -) -// ?.also { println("\n" + Operations.determinizeER(it).toDot().alsoCopy() + "\n") } -// .also { println("Total: ${Automata.transitions(it).size} arcs, ${Automata.states(it).size}") } -// .let { WAutomata.bestStrings(it, maxResults).map { it.label.joinToString(" ") }.toSet() } +fun BAutomaton.min(): BAutomaton = minimize(this) + +fun PTree.toDFA(minimize: Boolean = false) = + measureTimedValue { + BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI) + var i = 0 + var j = 0 + propagator( + both = { a, b -> if (a == null) b else if (b == null) a + // Only periodically minimize the automata during construction + else if (i++ % 13 == 0) a.concatenate(b).min() else a.concatenate(b) }, + either = { a, b -> if (a == null) b else if (b == null) a + else if (j++ % 13 == 0) a.union(b).min() else a.union(b) }, + unit = { a -> + if ("ε" in a.root) null + else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar()) + } + ) + }.also { println("Took ${it.duration} to build FSA") }.value ?.also { println("Original automata had ${it .let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions"}}") } ?.also { - measureTimedValue { BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI); BAutomaton.minimize(it) } + if (minimize) measureTimedValue { BAutomaton.minimize(it) } .also { println("Minimization took ${it.duration}") }.value // .also { it.toDot().replaceAll(stbl).alsoCopy() } .also { @@ -83,48 +88,66 @@ fun PTree.decodeDFA(mc: MarkovChain<Σᐩ>, topK: Int = 10_000_000): List<Σᐩ> }") } } -// ?.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") } ?: emptySet() - ?.steerableRandomWalk( - mc = mc, - dec = allTerminals.associateBy { Random(it.hashCode()).nextInt().toChar() }, - topK = topK - ) ?: emptyList() // Steers a random walk using the last n-1 transitions from the Markov Chain -fun BAutomaton.steerableRandomWalk( +fun BAutomaton.decodeDFA( mc: MarkovChain<Σᐩ>, // BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a // string-based alphabet, so we need a way to translate between the two - dec: Map, // Maps unicode characters back to strings - topK: Int // Total number of top-K results to return + dec: Map, // Maps unicode characters back to strings because BAutomata uses Unicode + callback: (Σᐩ) -> Unit = {}, + topK: Int = 10_000_000, // Total number of top-K results to return + timeout: Duration = Duration.INFINITE, + parallelize: Boolean = false ): List<Σᐩ> { val startTime = TimeSource.Monotonic.markNow() - val fullTrajectories = PriorityQueue(compareBy { it.score / it.toks.size }) - val partTrajectories = PriorityQueue(compareBy { it.score / it.toks.size }) - .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) } - while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) { - val partTraj = partTrajectories.remove() - val lastToks = partTraj.toks.take(mc.memory - 1).reversed() - partTraj.lastState.transitions.forEach { next: Transition -> - (next.min..next.max).forEach { tok -> - val decTok = dec[tok] - val nextToks = lastToks + decTok - val nextScore = partTraj.score + mc.scoreChunk(nextToks) - val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore) - if (!traj.isComplete) partTrajectories.add(traj) - else { - fullTrajectories.add(traj) - if (traj.lastState.transitions.isNotEmpty()) - partTrajectories.add(traj) + val load = 100_000 + val fullTrajectories = PriorityBlockingQueue(load, compareBy { it.score / it.traj.size }) + val partTrajectories = Array(if(parallelize) NUM_CORES else 1) { + PriorityBlockingQueue(load, compareBy { it.score / it.traj.size }) + .apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) } + } + + fun task(id: Int = 0) { + var i = 0 + while ( + fullTrajectories.size < topK && + partTrajectories.any { it.size > 0 } && + startTime.elapsedNow() < timeout + ) { + if (partTrajectories[id].isEmpty()) continue +// Checks for balanced distribution of work across cores +// if (i++ % 9999 == 0) println("Trajectories[$id]: ${partTrajectories.map {it.size}}") + val partTraj = partTrajectories[id].remove() + val lastToks = partTraj.traj.take(mc.memory - 1).reversed() + partTraj.lastState.transitions.forEach { next: Transition -> + (next.min..next.max).forEach { tok -> + val decTok = dec[tok] + val nextToks = lastToks + decTok + val nextScore = partTraj.score + mc.scoreChunk(nextToks) + val traj = FSATrajectory(listOf(decTok) + partTraj.traj, next.dest, nextScore) + val bin = if (parallelize) Random(traj.score.hashCode()).nextInt(NUM_CORES) else 0 + if (!traj.isComplete) partTrajectories[bin].add(traj) + else { + fullTrajectories.add(traj) + callback(traj.toString()) + if (traj.lastState.transitions.isNotEmpty()) + partTrajectories[bin].add(traj) + } } } } } - println("Top 10 trajectories:") - fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") } - println("Took ${startTime.elapsedNow()} to decode ${fullTrajectories.size} trajectories") + if (parallelize) (0..( train: Sequence = sequenceOf(), val memory: Int = 3, - val counter: Counter = Counter(train, memory) + val counter: Counter = Counter(train, memory), + var scorePrefix: List = listOf(), + var scoreSuffix: List = listOf() ) { private val mgr = ResettableLazyManager() @@ -120,7 +122,7 @@ open class MarkovChain( // Computes perplexity of a sequence normalized by sequence length (lower is better) fun score(seq: List): Double = - if (memory < seq.size) -seq.windowed(memory) + if (memory < seq.size) -(scorePrefix + seq + scoreSuffix).windowed(memory) .map { (getAtLeastOne(it) + 1) / (getAtLeastOne(it.dropLast(1) + null) + dictionary.size) } .sumOf { ln(it) } / seq.size else (seq.sumOf { counter.rawCounts.getEstimate(it) } + 1).toDouble() / counter.total.toDouble() @@ -146,7 +148,10 @@ open class MarkovChain( } // https://www.cs.utah.edu/~jeffp/papers/merge-summ.pdf - operator fun plus(mc: MarkovChain) = MarkovChain(memory = memory, counter = counter + mc.counter) + operator fun plus(mc: MarkovChain) = MarkovChain( + memory = memory, counter = counter + mc.counter, + scorePrefix = scorePrefix, scoreSuffix = scoreSuffix + ) /** * TODO: construct [Dist] using precomputed normalization constants [Counter.nrmCounts] diff --git a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt index 745002e6..942b7e5e 100644 --- a/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt +++ b/src/jvmTest/kotlin/ai/hypergraph/kaliningraph/automata/WFSATest.kt @@ -21,6 +21,7 @@ class WFSATest { // readBIFIContents() val csv = File(File("").absolutePath + "/src/jvmTest/resources/ngrams_BIFI_$MARKOV_MEMORY.csv") MarkovChain.deserialize(csv.readText()) + .apply { scorePrefix = listOf("BOS", "NEWLINE"); scoreSuffix = listOf("EOS") } .also { println("Loaded ${it.counter.total} BIFI $MARKOV_MEMORY-grams from ${csv.absolutePath}") } } @@ -29,6 +30,7 @@ class WFSATest { val P_PY150: MarkovChain<Σᐩ> by lazy { val csv = File(File("").absolutePath + "/src/jvmTest/resources/ngrams_PY150_$MARKOV_MEMORY.csv") MarkovChain.deserialize(csv.readText()) + .apply { scorePrefix = listOf("BOS", "NEWLINE"); scoreSuffix = listOf("EOS") } .also { println("Loaded ${it.counter.total} PY150 $MARKOV_MEMORY-grams from ${csv.absolutePath}") } } @@ -88,7 +90,7 @@ class WFSATest { val ptreeRepairs = measureTimedValue { pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet() } - measureTimedValue { pt.decodeDFA(P_BIFI_PY150) }.also { + measureTimedValue { pt.toDFA()!!.decodeDFA(P_BIFI_PY150, dec = pt.termDict, parallelize = true) }.also { assertTrue(groundTr in it.value, "Ground truth not found in ${it.value.size} repairs") println("Index: ${it.value.indexOf(groundTr)}") // // Print side by side comparison of repairs