Skip to content

Commit

Permalink
Update for HMCMC telemetry output
Browse files Browse the repository at this point in the history
  • Loading branch information
gvonness committed Aug 2, 2023
1 parent f0c1a6a commit f548575
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 25 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ThisBuild / baseVersion := "0.7.0"
ThisBuild / baseVersion := "0.8.0"

ThisBuild / organization := "ai.entrolution"
ThisBuild / organizationName := "Greg von Nessi"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ import scala.collection.immutable.Queue

case class HmcmcSampledPosterior[F[_]: STM: Async](
private[thylacine] val hmcmcConfig: HmcmcConfig,
protected override val hamiltonianDifferentialUpdateCallback: Double => F[Unit],
protected override val sampleProcessedCallback: HmcmcTelemetryUpdate => F[Unit],
protected override val telemetryUpdateCallback: HmcmcTelemetryUpdate => F[Unit],
private[thylacine] val seed: Map[String, Vector[Double]],
private[thylacine] override val priors: Set[Prior[F, _]],
private[thylacine] override val likelihoods: Set[Likelihood[F, _, _]],
Expand Down Expand Up @@ -81,8 +80,7 @@ object HmcmcSampledPosterior {
def of[F[_]: STM: Async](
hmcmcConfig: HmcmcConfig,
posterior: Posterior[F, Prior[F, _], Likelihood[F, _, _]],
hamiltonianDifferentialUpdateCallback: Double => F[Unit],
sampleProcessedCallback: HmcmcTelemetryUpdate => F[Unit],
telemetryUpdateCallback: HmcmcTelemetryUpdate => F[Unit],
seed: Map[String, Vector[Double]]
): F[HmcmcSampledPosterior[F]] =
for {
Expand All @@ -94,18 +92,17 @@ object HmcmcSampledPosterior {
jumpAttempts <- TxnVar.of(0)
posterior <- Async[F].delay {
HmcmcSampledPosterior(
hmcmcConfig = hmcmcConfig,
hamiltonianDifferentialUpdateCallback = hamiltonianDifferentialUpdateCallback,
sampleProcessedCallback = sampleProcessedCallback,
seed = seed,
priors = posterior.priors,
likelihoods = posterior.likelihoods,
currentMcmcPositions = currentMcmcPositions,
burnInComplete = burnInComplete,
workTokenPool = workTokenPool,
jumpAcceptances = jumpAcceptances,
jumpAttempts = jumpAttempts,
numberOfSamplesRemaining = numberOfSamplesRemaining
hmcmcConfig = hmcmcConfig,
telemetryUpdateCallback = telemetryUpdateCallback,
seed = seed,
priors = posterior.priors,
likelihoods = posterior.likelihoods,
currentMcmcPositions = currentMcmcPositions,
burnInComplete = burnInComplete,
workTokenPool = workTokenPool,
jumpAcceptances = jumpAcceptances,
jumpAttempts = jumpAttempts,
numberOfSamplesRemaining = numberOfSamplesRemaining
)
}
_ <- posterior.launchInitialisation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,22 @@ package thylacine.model.core.telemetry
case class HmcmcTelemetryUpdate(
samplesRemaining: Int,
jumpAttempts: Int,
jumpAcceptances: Int
jumpAcceptances: Int,
hamiltonianDifferential: Option[Double]
) extends TelemetryReport {

override lazy val logMessage: String =
s"HMCMC Sampling :: Samples remaining - $samplesRemaining // Acceptance Ratio - ${jumpAcceptances.toDouble / jumpAttempts}"
override lazy val logMessage: String = {
val baseString =
s"HMCMC Sampling :: Samples remaining - $samplesRemaining // Acceptance Ratio - ${jumpAcceptances.toDouble / jumpAttempts}"
val acceptanceString =
if (jumpAttempts != 0) {
s" // Acceptance Ratio - ${jumpAcceptances.toDouble / jumpAttempts}"
} else {
""
}
val hamiltonianDifferentialString =
hamiltonianDifferential.map(v => s" // exp(-dH) = ${Math.exp(-v)}").getOrElse("")

baseString + acceptanceString + hamiltonianDifferentialString
}
}
19 changes: 14 additions & 5 deletions src/main/scala/thylacine/model/sampling/hmcmc/HmcmcEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ private[thylacine] trait HmcmcEngine[F[_]] extends ModelParameterSampler[F] {
protected def warmUpSimulationCount: Int

protected def startingPoint: F[ModelParameterCollection]
protected def hamiltonianDifferentialUpdateCallback: Double => F[Unit]
protected def sampleProcessedCallback: HmcmcTelemetryUpdate => F[Unit]
protected def telemetryUpdateCallback: HmcmcTelemetryUpdate => F[Unit]

/*
* - - -- --- ----- -------- -------------
Expand Down Expand Up @@ -148,7 +147,16 @@ private[thylacine] trait HmcmcEngine[F[_]] extends ModelParameterSampler[F] {
_ <- if (burnIn) {
Async[F].unit
} else {
hamiltonianDifferentialUpdateCallback(dH)
(for {
remainingSamples <- numberOfSamplesRemaining.get
jumps <- jumpAcceptances.get
attempts <- jumpAttempts.get
} yield HmcmcTelemetryUpdate(
samplesRemaining = remainingSamples,
jumpAttempts = attempts,
jumpAcceptances = jumps,
hamiltonianDifferential = Option(dH)
)).commit.flatMap(telemetryUpdateCallback)
}
result <- Async[F].ifM(Async[F].delay(dH < 0 || Math.random() < Math.exp(-dH)))(
for {
Expand Down Expand Up @@ -242,8 +250,9 @@ private[thylacine] trait HmcmcEngine[F[_]] extends ModelParameterSampler[F] {
jumps <- jumpAcceptances.get
attempts <- jumpAttempts.get
} yield (jumps, attempts)).commit
telemetryResult <- Async[F].delay(HmcmcTelemetryUpdate(numberOfSamples, jumpsAndAttempts._2, jumpsAndAttempts._1))
_ <- sampleProcessedCallback(telemetryResult)
telemetryResult <-
Async[F].delay(HmcmcTelemetryUpdate(numberOfSamples, jumpsAndAttempts._2, jumpsAndAttempts._1, None))
_ <- telemetryUpdateCallback(telemetryResult)
} yield ()

private val getHmcmcSample: F[ModelParameterCollection] =
Expand Down

0 comments on commit f548575

Please sign in to comment.