From b9134ff509ebd29e8d2f34a282b37152703f1d7d Mon Sep 17 00:00:00 2001 From: breandan Date: Sun, 7 Jul 2024 15:23:50 -0400 Subject: [PATCH] evaluate PCFG NLL reranker --- .../kaliningraph/parsing/BarHillel.kt | 2 +- .../kaliningraph/parsing/Normalization.kt | 3 +- .../kaliningraph/parsing/SeqValiant.kt | 33 +++++++++++++++++-- .../hypergraph/kaliningraph/parsing/Tree.kt | 7 ++-- .../kaliningraph/parsing/JVMBarHillel.kt | 12 +++++++ 5 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt index f9f04c55..942f578c 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/BarHillel.kt @@ -27,7 +27,7 @@ fun CFG.barHillelRepair(prompt: List<Σᐩ>, distance: Int) = // http://www.cs.umd.edu/~gasarch/BLOGPAPERS/cfg.pdf#page=2 // https://browse.arxiv.org/pdf/2209.06809.pdf#page=5 -private fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG { +fun CFG.intersectLevFSAP(fsa: FSA, parikhMap: ParikhMap = this.parikhMap): CFG { var clock = TimeSource.Monotonic.markNow() val nts = mutableSetOf("START") fun Σᐩ.isSyntheticNT() = diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Normalization.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Normalization.kt index 12b0cd5f..b45023f4 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Normalization.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Normalization.kt @@ -43,7 +43,8 @@ fun CFG.transformIntoCNF(): CFG = addEpsilonProduction() .refactorEpsilonProds() .elimVarUnitProds() - .binarizeRHSByFrequency() +// .binarizeRHSByFrequency() + .binarizeRHSByRightmost() .terminalsToUnitProds() .removeUselessSymbols() diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt index ef414e83..60f7fd44 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/SeqValiant.kt @@ -7,10 +7,10 @@ import ai.hypergraph.kaliningraph.tensor.UTMatrix import ai.hypergraph.kaliningraph.types.* import com.ionspin.kotlin.bignum.integer.* import kotlin.jvm.JvmName +import kotlin.math.ln import kotlin.random.* import kotlin.time.measureTimedValue - // Indexes a set of PTrees by their roots typealias PForest = Map // ℙ₃ // Algebraic data type / polynomial functor for parse forests (ℙ₂) @@ -108,6 +108,20 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() return if (left.isEmpty()) right else if (right.isEmpty()) left else "$left $right" } + private fun newDecoderWithProb(i: BigInteger, pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int>): Pair { + if (branches.isEmpty()) return epsStr to 0.0 + val t = ranges.indexOfFirst { it.first <= i && i <= it.second } + val (l, r) = branches[t] + val q = i - ranges[t].first + val (iLeft, iRight) = q.divrem(r.totalTrees) + val (lroot, rroot) = l.rootName to r.rootName + val (left, leftScore) = l.newDecoderWithProb(iLeft, pcfgMap, pcfgNorm) + val (right, rightScore) = r.newDecoderWithProb(iRight, pcfgMap, pcfgNorm) + val myScore = ln((pcfgMap[root to lroot to rroot]?.toDouble() ?: 0.00001) / (pcfgNorm[root]?.toDouble() ?: 1.0)) + + leftScore + rightScore + return (if (left.isEmpty()) right else if (right.isEmpty()) left else "$left $right") to myScore + } + // Average time: 436.96ms, total time 43696.959ms (testRandomCFG) private fun decodeString(i: BigInteger): Pair { if (branches.isEmpty()) return epsStr to i @@ -154,6 +168,20 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() while (i < totalTrees) { yield(newDecoder(i)); i++} } + // Returns trees WoR from the CFG and scores the strings with a PCFG-based log-likelihood + fun sampleStrWithoutReplacementAndScore( + stride: Int = 1, offset: Int = 0, + pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int> + ): Sequence<Π2> = + if (6 < totalTrees.bitLength()) + bigLFSRSequence(totalTrees).mapIndexedNotNull { index, i -> + if (index % stride == offset) newDecoderWithProb(i, pcfgMap, pcfgNorm) else null + } + else sequence { + var i = BigInteger.ZERO + while (i < totalTrees) { yield(newDecoderWithProb(i, pcfgMap, pcfgNorm)); i++} + } + fun sampleStrWithPCFG5(pcfgTable: Map): Sequence = sequence { while (true) yield(samplePCFG5(pcfgTable)) } @@ -186,6 +214,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() if (a.isEmpty()) b else if (b.isEmpty()) a else "$a $b" } + /** See [intersectLevFSAP], extracts original NT name from a synthetic ∩-NT. */ fun Σᐩ.name() = if ('~' in this) split('~')[1] else this val triples : List<Π2A> by lazy { branches.map { it.first.ntIdx to it.second.ntIdx } } val rootName by lazy { root.name() } @@ -199,7 +228,7 @@ class PTree(val root: String = ".ε", val branches: List<Π2A> = listOf() // .also { if(Random.nextInt(10000) == 3) if (it == 1) println("$hash Miss"); else println("$hash Hit") } + 1 } val cdf = probs.runningReduce { acc, i -> acc + i } - val rnd = Random.nextInt(probs.sum()) + val rnd = Random.nextInt(cdf.last()) val childIdx = cdf.binarySearch { it.compareTo(rnd) }.let { if (it < 0) -it - 1 else it } val (l, r) = branches[childIdx] val (lr, rr) = l.ntIdx to r.ntIdx diff --git a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Tree.kt b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Tree.kt index 6e626d5e..c0d6f47b 100644 --- a/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Tree.kt +++ b/src/commonMain/kotlin/ai/hypergraph/kaliningraph/parsing/Tree.kt @@ -4,6 +4,7 @@ import ai.hypergraph.kaliningraph.graphs.LGVertex import ai.hypergraph.kaliningraph.graphs.LabeledGraph import ai.hypergraph.kaliningraph.tensor.FreeMatrix import ai.hypergraph.kaliningraph.types.* +import kotlin.math.ln typealias TreeMatrix = FreeMatrix typealias Forest = Set @@ -39,10 +40,10 @@ class Tree constructor( children[0].quintuples(root, children[0].root + "*", children[1].root) + children[1].quintuples(root, children[0].root, children[1].root + "*") - fun logProb(pcfgMap: Map<Π3A<Σᐩ>, Int>): Double = + fun logProb(pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int>): Double = if (children.isEmpty()) 0.0 - else (pcfgMap[root to children[0].root to children[1].root]?.toDouble() ?: 0.0) + - children.sumOf { it.logProb(pcfgMap) } + else ln((pcfgMap[root to children[0].root to children[1].root]?.toDouble() ?: 0.00001) / (pcfgNorm[root]?.toDouble() ?: 1.0)) + + children.sumOf { it.logProb(pcfgMap, pcfgNorm) } fun toGraph(j: Σᐩ = "0"): LabeledGraph = LabeledGraph { LGVertex(root, "$root.$j").let { it - it } } + diff --git a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt index 03cce49d..753c0417 100644 --- a/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt +++ b/src/jvmMain/kotlin/ai/hypergraph/kaliningraph/parsing/JVMBarHillel.kt @@ -99,6 +99,18 @@ fun PTree.sampleDirectlyWOR( .asStream() } +fun PTree.sampleDirectlyWORAndScore( + cores: Int = NUM_CORES, + stoppingCriterion: () -> Boolean = { true }, + pcfgMap: Map<Π3A<Σᐩ>, Int>, pcfgNorm: Map<Σᐩ, Int> +): Stream<Π2> = + (0.. + sampleStrWithoutReplacementAndScore(cores, i, pcfgMap, pcfgNorm) + .takeWhile { stoppingCriterion() } + .distinctBy { it.first } + .asStream() + } + fun CFG.parallelEnumListWR( prompt: List, cores: Int = NUM_CORES,