Skip to content

Commit

Permalink
integrate GWA decoder into LBH repair
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jun 25, 2024
1 parent 72889cb commit 4f0e94a
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 57 deletions.
12 changes: 11 additions & 1 deletion src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/CFG.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import ai.hypergraph.kaliningraph.sampling.choose
import ai.hypergraph.kaliningraph.tokenizeByWhitespace
import ai.hypergraph.kaliningraph.types.*
import kotlin.jvm.JvmName
import kotlin.random.Random
import kotlin.time.*
import kotlin.time.Duration.Companion.seconds

Expand Down Expand Up @@ -388,4 +389,13 @@ fun CFG.jsonify() = "cfg = {\n" +
("\"${it.key}\": [${it.value.joinToString(", ") {
it.joinToString(", ", "(", ")") { "\"$it\"" }
}}],")
} + "\n}"
} + "\n}"

class TermDict(
val terms: Set<Σᐩ>,
val dict: Map<Char, Σᐩ> = terms.associateBy { Random(it.hashCode()).nextInt().toChar() },
val revDict: Map<Σᐩ, Char> = dict.entries.associate { (k, v) -> v to k }
) : Map<Char, Σᐩ> by dict {
fun encode(str: String) = str.tokenizeByWhitespace().map { revDict[it]!! }.joinToString("")
fun encode(str: List<String>) = str.map { revDict[it]!! }.joinToString("")
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ class PTree(val root: String = ".ε", val branches: List<Π2A<PTree>> = listOf()
(1 + branches.sumOf { (l, r) -> l.branchRatio.second + r.branchRatio.second })
}

val allTerminals: Set<String> by lazy {
val allTerminals: Set<Σᐩ> by lazy {
if (branches.isEmpty()) setOf(root)
else branches.map { (l, r) -> l.allTerminals + r.allTerminals }.flatten().toSet()
}

val termDict by lazy { TermDict(allTerminals) }

// Σ^n/|T(n)|, if < 1 then we know the grammar is surely ambiguous
val inverseDensity by lazy {
measureTimedValue { allTerminals.size.toBigInteger().pow(depth) / totalTrees }
Expand Down
125 changes: 74 additions & 51 deletions src/jvmMain/kotlin/ai/hypergraph/kaliningraph/automata/JFSA.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package ai.hypergraph.kaliningraph.automata

import NUM_CORES
import ai.hypergraph.kaliningraph.graphs.LabeledGraph
import ai.hypergraph.kaliningraph.parsing.*
import ai.hypergraph.markovian.mcmc.MarkovChain
import dk.brics.automaton.Automaton.*
import dk.brics.automaton.Transition
import java.util.*
import java.util.concurrent.*
import kotlin.random.Random
import kotlin.time.*

Expand Down Expand Up @@ -49,31 +50,35 @@ fun JAutomaton<String, Double>.toDot(processed: MutableSet<Any> = mutableSetOf()
* previous n-1 transitions, i.e., q' ~ argmax_{q'} P(q' | q_{t-1}, ..., q_{t-n+1})
*/

data class FSATrajectory(val toks: List<Σᐩ?>, val lastState: BState, val score: Double) {
data class FSATrajectory(val traj: List<Σᐩ?>, val lastState: BState, val score: Double) {
val isComplete: Boolean = lastState.isAccept
override fun toString() = toks.reversed().filterNotNull().joinToString(" ")
val tokens by lazy { traj.reversed().filterNotNull() }
override fun toString() = tokens.joinToString(" ")
}

fun PTree.decodeDFA(mc: MarkovChain<Σᐩ>, topK: Int = 10_000_000): List<Σᐩ> = propagator(
both = { a, b -> if (a == null) b else if (b == null) a else a.concatenate(b) },
either = { a, b -> if (a == null) b else if (b == null) a else a.union(b) },
unit = { a ->
if ("ε" in a.root) null
else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar())
// EditableAutomaton<String, Double>(RealSemiring()).apply {
// val s1 = addState(1.0, 0.0)
// val s2 = addState(0.0, 1.0)
// addTransition(s1, s2, a.root, 1.0)
// }
}
)
// ?.also { println("\n" + Operations.determinizeER(it).toDot().alsoCopy() + "\n") }
// .also { println("Total: ${Automata.transitions(it).size} arcs, ${Automata.states(it).size}") }
// .let { WAutomata.bestStrings(it, maxResults).map { it.label.joinToString(" ") }.toSet() }
fun BAutomaton.min(): BAutomaton = minimize(this)

fun PTree.toDFA(minimize: Boolean = false) =
measureTimedValue {
BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI)
var i = 0
var j = 0
propagator(
both = { a, b -> if (a == null) b else if (b == null) a
// Only periodically minimize the automata during construction
else if (i++ % 13 == 0) a.concatenate(b).min() else a.concatenate(b) },
either = { a, b -> if (a == null) b else if (b == null) a
else if (j++ % 13 == 0) a.union(b).min() else a.union(b) },
unit = { a ->
if ("ε" in a.root) null
else BAutomaton.makeChar(Random(a.root.hashCode()).nextInt().toChar())
}
)
}.also { println("Took ${it.duration} to build FSA") }.value
?.also { println("Original automata had ${it
.let { "${it.numberOfStates} states and ${it.numberOfTransitions} transitions"}}") }
?.also {
measureTimedValue { BAutomaton.setMinimization(MINIMIZE_BRZOZOWSKI); BAutomaton.minimize(it) }
if (minimize) measureTimedValue { BAutomaton.minimize(it) }
.also { println("Minimization took ${it.duration}") }.value
// .also { it.toDot().replaceAll(stbl).alsoCopy() }
.also {
Expand All @@ -83,48 +88,66 @@ fun PTree.decodeDFA(mc: MarkovChain<Σᐩ>, topK: Int = 10_000_000): List<Σᐩ>
}")
}
}
// ?.getFiniteStrings(-1)?.map { it.map { ctbl[it] }.joinToString(" ") } ?: emptySet()
?.steerableRandomWalk(
mc = mc,
dec = allTerminals.associateBy { Random(it.hashCode()).nextInt().toChar() },
topK = topK
) ?: emptyList()

// Steers a random walk using the last n-1 transitions from the Markov Chain
fun BAutomaton.steerableRandomWalk(
fun BAutomaton.decodeDFA(
mc: MarkovChain<Σᐩ>,
// BAutomata uses a Unicode alphabet, and the Markov Chain recognizes a
// string-based alphabet, so we need a way to translate between the two
dec: Map<Char, String>, // Maps unicode characters back to strings
topK: Int // Total number of top-K results to return
dec: Map<Char, Σᐩ>, // Maps unicode characters back to strings because BAutomata uses Unicode
callback: (Σᐩ) -> Unit = {},
topK: Int = 10_000_000, // Total number of top-K results to return
timeout: Duration = Duration.INFINITE,
parallelize: Boolean = false
): List<Σᐩ> {
val startTime = TimeSource.Monotonic.markNow()
val fullTrajectories = PriorityQueue<FSATrajectory>(compareBy { it.score / it.toks.size })
val partTrajectories = PriorityQueue<FSATrajectory>(compareBy { it.score / it.toks.size })
.apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
while (fullTrajectories.size < topK && partTrajectories.isNotEmpty()) {
val partTraj = partTrajectories.remove()
val lastToks = partTraj.toks.take(mc.memory - 1).reversed()
partTraj.lastState.transitions.forEach { next: Transition ->
(next.min..next.max).forEach { tok ->
val decTok = dec[tok]
val nextToks = lastToks + decTok
val nextScore = partTraj.score + mc.scoreChunk(nextToks)
val traj = FSATrajectory(listOf(decTok) + partTraj.toks, next.dest, nextScore)
if (!traj.isComplete) partTrajectories.add(traj)
else {
fullTrajectories.add(traj)
if (traj.lastState.transitions.isNotEmpty())
partTrajectories.add(traj)
val load = 100_000
val fullTrajectories = PriorityBlockingQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
val partTrajectories = Array(if(parallelize) NUM_CORES else 1) {
PriorityBlockingQueue<FSATrajectory>(load, compareBy { it.score / it.traj.size })
.apply { add(FSATrajectory(List(mc.memory) { null }, initialState, 0.0)) }
}

fun task(id: Int = 0) {
var i = 0
while (
fullTrajectories.size < topK &&
partTrajectories.any { it.size > 0 } &&
startTime.elapsedNow() < timeout
) {
if (partTrajectories[id].isEmpty()) continue
// Checks for balanced distribution of work across cores
// if (i++ % 9999 == 0) println("Trajectories[$id]: ${partTrajectories.map {it.size}}")
val partTraj = partTrajectories[id].remove()
val lastToks = partTraj.traj.take(mc.memory - 1).reversed()
partTraj.lastState.transitions.forEach { next: Transition ->
(next.min..next.max).forEach { tok ->
val decTok = dec[tok]
val nextToks = lastToks + decTok
val nextScore = partTraj.score + mc.scoreChunk(nextToks)
val traj = FSATrajectory(listOf(decTok) + partTraj.traj, next.dest, nextScore)
val bin = if (parallelize) Random(traj.score.hashCode()).nextInt(NUM_CORES) else 0
if (!traj.isComplete) partTrajectories[bin].add(traj)
else {
fullTrajectories.add(traj)
callback(traj.toString())
if (traj.lastState.transitions.isNotEmpty())
partTrajectories[bin].add(traj)
}
}
}
}
}

println("Top 10 trajectories:")
fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
println("Took ${startTime.elapsedNow()} to decode ${fullTrajectories.size} trajectories")
if (parallelize) (0..<NUM_CORES).toList().parallelStream().forEach { task(it) } else task(0)

return fullTrajectories.map { it.toString() }
}
// Deduplicate and resort by final score
val deduped = fullTrajectories.parallelStream().map { it.toString() to mc.score(it.tokens) }
.distinct().toList().sortedBy { it.second }.map { it.first }

// println("Top 10 trajectories:")
// fullTrajectories.take(10).forEach { println(it.score.toString().take(5) + ": $it") }
println("Took ${startTime.elapsedNow()} to decode ${fullTrajectories.size} trajectories")

return deduped
}
11 changes: 8 additions & 3 deletions src/jvmMain/kotlin/ai/hypergraph/markovian/mcmc/MarkovChain.kt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ val maxUniques: Int = 2000
open class MarkovChain<T>(
train: Sequence<T> = sequenceOf(),
val memory: Int = 3,
val counter: Counter<T> = Counter(train, memory)
val counter: Counter<T> = Counter(train, memory),
var scorePrefix: List<T> = listOf(),
var scoreSuffix: List<T> = listOf()
) {
private val mgr = ResettableLazyManager()

Expand Down Expand Up @@ -120,7 +122,7 @@ open class MarkovChain<T>(

// Computes perplexity of a sequence normalized by sequence length (lower is better)
fun score(seq: List<T?>): Double =
if (memory < seq.size) -seq.windowed(memory)
if (memory < seq.size) -(scorePrefix + seq + scoreSuffix).windowed(memory)
.map { (getAtLeastOne(it) + 1) / (getAtLeastOne(it.dropLast(1) + null) + dictionary.size) }
.sumOf { ln(it) } / seq.size
else (seq.sumOf { counter.rawCounts.getEstimate(it) } + 1).toDouble() / counter.total.toDouble()
Expand All @@ -146,7 +148,10 @@ open class MarkovChain<T>(
}

// https://www.cs.utah.edu/~jeffp/papers/merge-summ.pdf
operator fun plus(mc: MarkovChain<T>) = MarkovChain<T>(memory = memory, counter = counter + mc.counter)
operator fun plus(mc: MarkovChain<T>) = MarkovChain<T>(
memory = memory, counter = counter + mc.counter,
scorePrefix = scorePrefix, scoreSuffix = scoreSuffix
)

/**
* TODO: construct [Dist] using precomputed normalization constants [Counter.nrmCounts]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class WFSATest {
// readBIFIContents()
val csv = File(File("").absolutePath + "/src/jvmTest/resources/ngrams_BIFI_$MARKOV_MEMORY.csv")
MarkovChain.deserialize(csv.readText())
.apply { scorePrefix = listOf("BOS", "NEWLINE"); scoreSuffix = listOf("EOS") }
.also { println("Loaded ${it.counter.total} BIFI $MARKOV_MEMORY-grams from ${csv.absolutePath}") }
}

Expand All @@ -29,6 +30,7 @@ class WFSATest {
val P_PY150: MarkovChain<Σᐩ> by lazy {
val csv = File(File("").absolutePath + "/src/jvmTest/resources/ngrams_PY150_$MARKOV_MEMORY.csv")
MarkovChain.deserialize(csv.readText())
.apply { scorePrefix = listOf("BOS", "NEWLINE"); scoreSuffix = listOf("EOS") }
.also { println("Loaded ${it.counter.total} PY150 $MARKOV_MEMORY-grams from ${csv.absolutePath}") }
}

Expand Down Expand Up @@ -88,7 +90,7 @@ class WFSATest {
val ptreeRepairs = measureTimedValue {
pt.sampleStrWithoutReplacement().distinct().take(maxResults).toSet()
}
measureTimedValue { pt.decodeDFA(P_BIFI_PY150) }.also {
measureTimedValue { pt.toDFA()!!.decodeDFA(P_BIFI_PY150, dec = pt.termDict, parallelize = true) }.also {
assertTrue(groundTr in it.value, "Ground truth not found in ${it.value.size} repairs")
println("Index: ${it.value.indexOf(groundTr)}")
// // Print side by side comparison of repairs
Expand Down

0 comments on commit 4f0e94a

Please sign in to comment.