Skip to content

Commit

Permalink
Refactor to use F instead of thunks to handle suspended effects
Browse files Browse the repository at this point in the history
  • Loading branch information
gvonness committed Sep 6, 2022
1 parent 5da96d5 commit 06dafa9
Show file tree
Hide file tree
Showing 16 changed files with 668 additions and 535 deletions.
48 changes: 27 additions & 21 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ThisBuild / baseVersion := "0.6.1"
ThisBuild / baseVersion := "0.7.0"

ThisBuild / organization := "ai.entrolution"
ThisBuild / organizationName := "Greg von Nessi"
Expand Down
15 changes: 9 additions & 6 deletions src/main/scala/bengal/stm/STM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ package bengal.stm
import bengal.stm.api.internal.TxnApiContext
import bengal.stm.model._
import bengal.stm.model.runtime._
import bengal.stm.runtime.TxnRuntimeContext
import bengal.stm.runtime.{TxnCompilerContext, TxnLogContext, TxnRuntimeContext}

import cats.effect.Ref
import cats.effect.implicits._
import cats.effect.implicits.genSpawnOps
import cats.effect.kernel.{Async, Deferred}
import cats.effect.std.Semaphore
import cats.implicits._

import scala.concurrent.duration.{FiniteDuration, NANOSECONDS}

trait STM[F[_]]
extends TxnRuntimeContext[F]
abstract class STM[F[_]: Async]
extends AsyncImplicits[F]
with TxnRuntimeContext[F]
with TxnCompilerContext[F]
with TxnLogContext[F]
with TxnApiContext[F]
with TxnAdtContext[F] {

Expand Down Expand Up @@ -123,12 +126,12 @@ object STM {
}

override def allocateTxnVar[V](value: V): F[TxnVar[F, V]] =
TxnVar.of(value)(this, Async[F])
TxnVar.of(value)(this, this.asyncF)

override def allocateTxnVarMap[K, V](
valueMap: Map[K, V]
): F[TxnVarMap[F, K, V]] =
TxnVarMap.of(valueMap)(this, Async[F])
TxnVarMap.of(valueMap)(this, this.asyncF)

override private[stm] def commitTxn[V](txn: Txn[V]): F[V] =
txnRuntime.commit(txn)
Expand Down
86 changes: 77 additions & 9 deletions src/main/scala/bengal/stm/api/internal/TxnApiContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ package bengal.stm.api.internal

import bengal.stm.model.TxnErratum._
import bengal.stm.model._
import bengal.stm.runtime.TxnRuntimeContext

import cats.effect.kernel.Async
import cats.free.Free

import scala.util.{Failure, Success, Try}

private[stm] trait TxnApiContext[F[_]] {
this: TxnAdtContext[F] =>
this: AsyncImplicits[F] with TxnRuntimeContext[F] with TxnAdtContext[F] =>

private def liftSuccess[V](txnAdt: TxnAdt[V]): Txn[V] =
Free.liftF[TxnOrErr, V](Right(txnAdt))
Expand All @@ -36,8 +38,11 @@ private[stm] trait TxnApiContext[F[_]] {
val unit: Txn[Unit] =
liftSuccess(TxnUnit)

def fromF[V](spec: F[V]): Txn[V] =
liftSuccess(TxnDelay(spec))

def delay[V](thunk: => V): Txn[V] =
liftSuccess(TxnDelay(() => thunk))
liftSuccess(TxnDelay(Async[F].delay(thunk)))

def pure[V](value: V): Txn[V] =
liftSuccess(TxnPure(value))
Expand All @@ -58,7 +63,12 @@ private[stm] trait TxnApiContext[F[_]] {
private[stm] def handleErrorWithInternal[V](fa: => Txn[V])(
f: Throwable => Txn[V]
): Txn[V] =
liftSuccess(TxnHandleError(() => fa, f))
liftSuccess(TxnHandleError(Async[F].delay(fa), ex => Async[F].delay(f(ex))))

private[stm] def handleErrorWithInternalF[V](fa: => Txn[V])(
f: Throwable => F[Txn[V]]
): Txn[V] =
liftSuccess(TxnHandleError(Async[F].delay(fa), f))

private[stm] def getTxnVar[V](txnVar: TxnVar[F, V]): Txn[V] =
liftSuccess(TxnGetVar(txnVar))
Expand All @@ -67,14 +77,29 @@ private[stm] trait TxnApiContext[F[_]] {
newValue: => V,
txnVar: TxnVar[F, V]
): Txn[Unit] =
liftSuccess(TxnSetVar(() => newValue, txnVar))
liftSuccess(TxnSetVar(Async[F].delay(newValue), txnVar))

private[stm] def setTxnVarF[V](
newValue: F[V],
txnVar: TxnVar[F, V]
): Txn[Unit] =
liftSuccess(TxnSetVar(newValue, txnVar))

private[stm] def modifyTxnVar[V](f: V => V, txnVar: TxnVar[F, V]): Txn[Unit] =
for {
value <- getTxnVar(txnVar)
_ <- setTxnVar(f(value), txnVar)
} yield ()

private[stm] def modifyTxnVarF[V](
f: V => F[V],
txnVar: TxnVar[F, V]
): Txn[Unit] =
for {
value <- getTxnVar(txnVar)
_ <- setTxnVarF(f(value), txnVar)
} yield ()

private[stm] def getTxnVarMap[K, V](
txnVarMap: TxnVarMap[F, K, V]
): Txn[Map[K, V]] =
Expand All @@ -84,7 +109,13 @@ private[stm] trait TxnApiContext[F[_]] {
newValueMap: => Map[K, V],
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(TxnSetVarMap(() => newValueMap, txnVarMap))
liftSuccess(TxnSetVarMap(Async[F].delay(newValueMap), txnVarMap))

private[stm] def setTxnVarMapF[K, V](
newValueMap: F[Map[K, V]],
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(TxnSetVarMap(newValueMap, txnVarMap))

private[stm] def modifyTxnVarMap[K, V](
f: Map[K, V] => Map[K, V],
Expand All @@ -95,29 +126,66 @@ private[stm] trait TxnApiContext[F[_]] {
_ <- setTxnVarMap(f(value), txnVarMap)
} yield ()

private[stm] def modifyTxnVarMapF[K, V](
f: Map[K, V] => F[Map[K, V]],
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
for {
value <- getTxnVarMap(txnVarMap)
_ <- setTxnVarMapF(f(value), txnVarMap)
} yield ()

private[stm] def getTxnVarMapValue[K, V](
key: => K,
txnVarMap: TxnVarMap[F, K, V]
): Txn[Option[V]] =
liftSuccess(TxnGetVarMapValue(() => key, txnVarMap))
liftSuccess(TxnGetVarMapValue(Async[F].delay(key), txnVarMap))

private[stm] def setTxnVarMapValue[K, V](
key: => K,
newValue: => V,
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(TxnSetVarMapValue(() => key, () => newValue, txnVarMap))
liftSuccess(
TxnSetVarMapValue(Async[F].delay(key),
Async[F].delay(newValue),
txnVarMap
)
)

private[stm] def setTxnVarMapValueF[K, V](
key: => K,
newValue: F[V],
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(
TxnSetVarMapValue(Async[F].delay(key), newValue, txnVarMap)
)

private[stm] def modifyTxnVarMapValue[K, V](
key: => K,
f: V => V,
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(TxnModifyVarMapValue(() => key, f, txnVarMap))
liftSuccess(
TxnModifyVarMapValue(Async[F].delay(key),
(v: V) => Async[F].delay(f(v)),
txnVarMap
)
)

private[stm] def modifyTxnVarMapValueF[K, V](
key: => K,
f: V => F[V],
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(
TxnModifyVarMapValue(Async[F].delay(key), f, txnVarMap)
)

private[stm] def removeTxnVarMapValue[K, V](
key: => K,
txnVarMap: TxnVarMap[F, K, V]
): Txn[Unit] =
liftSuccess(TxnDeleteVarMapValue(() => key, txnVarMap))
liftSuccess(TxnDeleteVarMapValue(Async[F].delay(key), txnVarMap))
}
24 changes: 24 additions & 0 deletions src/main/scala/bengal/stm/model/AsyncImplicits.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2020-2022 Greg von Nessi
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.entrolution
package bengal.stm.model

import cats.effect.kernel.Async

private[stm] abstract class AsyncImplicits[F[_]](
protected implicit val asyncF: Async[F]
)
22 changes: 11 additions & 11 deletions src/main/scala/bengal/stm/model/TxnAdtContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,49 @@ private[stm] trait TxnAdtContext[F[_]] {

private[stm] case object TxnUnit extends TxnAdt[Unit]

private[stm] case class TxnDelay[V](thunk: () => V) extends TxnAdt[V]
private[stm] case class TxnDelay[V](thunk: F[V]) extends TxnAdt[V]

private[stm] case class TxnPure[V](value: V) extends TxnAdt[V]

private[stm] case class TxnGetVar[V](txnVar: TxnVar[F, V]) extends TxnAdt[V]

private[stm] case class TxnSetVar[V](
newValue: () => V,
newValue: F[V],
txnVar: TxnVar[F, V]
) extends TxnAdt[Unit]

private[stm] case class TxnGetVarMap[K, V](txnVarMap: TxnVarMap[F, K, V])
extends TxnAdt[Map[K, V]]

private[stm] case class TxnGetVarMapValue[K, V](
key: () => K,
key: F[K],
txnVarMap: TxnVarMap[F, K, V]
) extends TxnAdt[Option[V]]

private[stm] case class TxnSetVarMap[K, V](
newMap: () => Map[K, V],
newMap: F[Map[K, V]],
txnVarMap: TxnVarMap[F, K, V]
) extends TxnAdt[Unit]

private[stm] case class TxnSetVarMapValue[K, V](
key: () => K,
newValue: () => V,
key: F[K],
newValue: F[V],
txnVarMap: TxnVarMap[F, K, V]
) extends TxnAdt[Unit]

private[stm] case class TxnModifyVarMapValue[K, V](
key: () => K,
f: V => V,
key: F[K],
f: V => F[V],
txnVarMap: TxnVarMap[F, K, V]
) extends TxnAdt[Unit]

private[stm] case class TxnDeleteVarMapValue[K, V](
key: () => K,
key: F[K],
txnVarMap: TxnVarMap[F, K, V]
) extends TxnAdt[Unit]

private[stm] case class TxnHandleError[V](
fa: () => Txn[V],
f: Throwable => Txn[V]
fa: F[Txn[V]],
f: Throwable => F[Txn[V]]
) extends TxnAdt[V]
}
1 change: 0 additions & 1 deletion src/main/scala/bengal/stm/model/TxnErratum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package bengal.stm.model
private[stm] sealed trait TxnErratum

object TxnErratum {
private[stm] case object NoErratum extends TxnErratum

private[stm] case object TxnRetry extends TxnErratum

Expand Down
5 changes: 2 additions & 3 deletions src/main/scala/bengal/stm/model/TxnVar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package bengal.stm.model
import bengal.stm.STM
import bengal.stm.model.runtime._

import cats.effect.implicits._
import cats.effect.kernel.Async
import cats.effect.std.Semaphore
import cats.effect.{Deferred, Ref}
Expand All @@ -33,10 +32,10 @@ case class TxnVar[F[_]: Async, T](
private[stm] val txnRetrySignals: TxnSignals[F]
) extends TxnStateEntity[F, T] {

private def completeRetrySignals: F[Unit] =
private[stm] def completeRetrySignals: F[Unit] =
for {
signals <- txnRetrySignals.getAndSet(Set())
_ <- signals.toList.parTraverse(_.complete(()))
_ <- signals.toList.traverse(_.complete(()))
} yield ()

private[stm] lazy val get: F[T] =
Expand Down
Loading

0 comments on commit 06dafa9

Please sign in to comment.