diff --git a/README.md b/README.md index 3774bf8..78782c5 100644 --- a/README.md +++ b/README.md @@ -19,27 +19,33 @@ As there are already many solid references on STM, I will not dive into STM theo ## API -Example | Description | Type Signature | Notes -:--- | --- | :--- | :--- -`STM.runtime[F]` | Creates a runtime in an `F[_]` container whose transaction results can be lifted into a container `F[_]` via `commit` | `def runtime[F[_]: Async](retryMaxWait: FiniteDuration, maxWaitingToProcessInLoop: Int): F[STM[F]]`
or
`def runtime[F[_]: Async]: F[STM[F]]` (default `retryMaxWait`) | `retryMaxWait` is a backstop max amount of time to wait before retrying a transaction.

Default: `FiniteDuration(Long.MaxValue, NANOSECONDS)`

It is _not_ recommended to make this a small value (i.e. making retries effectively based on polling).

`maxWaitingToProcessInLoop` corresponds to max amount of waiting transactions the runtime will attempt to process in its runtime loop. It is not recommended to alter this value.

Default: `Runtime.getRuntime.availableProcessors() * 2` -`txnVar.get.commit` | Commits a transaction and lifts the result into `F[_]` | `def commit: F[V]` | -`TxnVar.of[List[Int]](List())` | Creates a transactional variable | ```def of[T](value: T): F[TxnVar[T]]``` -`TxnVarMap.of[String, Int](Map())` | Creates a transactional map | ```of[K, V](valueMap: Map[K, V]): F[TxnVarMap[K, V]]``` -`txnVar.get` | Retrieves value of transactional variable | ```def get: Txn[V]``` | -`txnVarMap.get` | Retrieves an immutable map (i.e. a view) representing transactional map state | ```def get: Txn[Map[K, V]]``` | Performance-wise it is better to retrieve individual keys instead of acquiring the entire map -`txnVarMap.get("David")` | Retrieves optional value depending on whether key exists in the map | ```def get(key: K): Txn[Option[V]]``` | Will raise an error if the key is never created (previously or current transaction). A `None` is returned if the value has been deleted in the current transaction. -`txnVar.set(100)` | Sets the value of transactional variable | ``` def set(newValue: V): Txn[Unit]``` -`txnVarMap.set(Map("David" -> 100))` | Uses an immutable map to set the transactional map state | ```def set(newValueMap: Map[K, V]): Txn[Unit]``` | Performance-wise it is better to set individual keys instead of setting the entire map in this manner.

This operation will create/delete key-values as needed to update the state of the map. -`txnVarMap.set("David", 100)` | Upserts the key-value into the transactional map | ```def set(key: K, newValue: V): Txn[Unit]``` | Will create the key-value in the transactional map, if the key was not present -`txnVar.modify(_ + 5)` | Modifies the value of a transactional variable | ```def modify(f: V => V): Txn[Unit]``` -`txnVarMap.modify("David", _ + 20)` | Modifies the value in a transactional map for a given key | ```def modify(key: K, f: V => V): Txn[Unit]``` | Will throw an error if the `key` is not present in the map (or has been deleted in the current transaction) -`txnVarMap.modify(_.map(i => i._1 -> i._2*2))` | Modifies all the values in the map | ```def modify(f: Map[K, V] => Map[K, V]): Txn[Unit]``` | Transform can create/delete entries.

Again, for performance it is better to work with individual key-value pairs instead of manipulating map views -`txnVarMap.remove("David")` | Removes a key-value from the transactional map | ```def remove(key: K): Txn[Unit]``` | Will throw an error if the key doesn't actually exist in the map (to be consistent with `get` behaviour) -`pure(10)` | Lifts a value into a transactional monad | ```def pure[V](value: V): Txn[V]``` | -`delay(10+2)` | Lifts a computation into a transactional monad (by-name value) | ```def delay[V](value: => V): Txn[V]``` | Argument will be evaluated every time a transaction is attempted. It is not advised to use with side effects. -`abort(new RuntimeException("foo"))` | Aborts the current transaction | ```def abort(ex: Throwable): Txn[Unit]``` | Variables/Maps changes in the transaction will not be changed if the transaction is aborted -`txn.handleErrorWith(_ => pure("bar"))` | Absorbs an error/abort and remaps to another transaction (of the same wrapped type) | ```def handleErrorWith(f: Throwable => Txn[V]): Txn[V]``` | -`waitFor(value > 10)` | Semantically blocks a transaction until a condition is met | ```def waitFor(predicate: => Boolean): Txn[Unit]``` | Blocking is only semantic (i.e. not locking up a thread while waiting)

This is implemented via retries that are initiated via variable/map updates. One can specify the `retryMaxWait` to facilitate backstop polling for these retries, but this is _not_ recommended (i.e. indicates side-effects are impacting predicate) +| Example | Description | Type Signature | Notes | +|:---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `STM.runtime[F]` | Creates a runtime in an `F[_]` container whose transaction results can be lifted into a container `F[_]` via `commit` | ```def runtime[F[_]: Async](retryMaxWait: FiniteDuration, maxWaitingToProcessInLoop: Int): F[STM[F]]```
or
```def runtime[F[_]: Async]: F[STM[F]]``` (default `retryMaxWait`) | `retryMaxWait` is a backstop max amount of time to wait before retrying a transaction.

Default: `FiniteDuration(Long.MaxValue, NANOSECONDS)`

It is _not_ recommended to make this a small value (i.e. making retries effectively based on polling).

`maxWaitingToProcessInLoop` corresponds to max amount of waiting transactions the runtime will attempt to process in its runtime loop. It is not recommended to alter this value.

Default: `Runtime.getRuntime.availableProcessors() * 2` | +| `txnVar.get.commit` | Commits a transaction and lifts the result into `F[_]` | ```def commit: F[V]``` | | +| `TxnVar.of[List[Int]](List())` | Creates a transactional variable | ```def of[T](value: T): F[TxnVar[T]]``` | | +| `TxnVarMap.of[String, Int](Map())` | Creates a transactional map | ```of[K, V](valueMap: Map[K, V]): F[TxnVarMap[K, V]]``` | | +| `txnVar.get` | Retrieves value of transactional variable | ```def get: Txn[V]``` | | +| `txnVarMap.get` | Retrieves an immutable map (i.e. a view) representing transactional map state | ```def get: Txn[Map[K, V]]``` | Performance-wise it is better to retrieve individual keys instead of acquiring the entire map | +| `txnVarMap.get("David")` | Retrieves optional value depending on whether key exists in the map | ```def get(key: K): Txn[Option[V]]``` | Will raise an error if the key is never created (previously or current transaction). A `None` is returned if the value has been deleted in the current transaction. | +| `txnVar.set(100)` | Sets the value of transactional variable | ```def set(newValue: V): Txn[Unit]``` | | +| `txnVar.setF(Async[F].pure(100))` | Sets the value of transactional variable via an abstract effect wrapped in `F` | ```def setF[F[_]: Async](newValue: V): Txn[Unit]``` | Need to ensure `F[V]` does not encapsulate side-effects | +| `txnVarMap.set(Map("David" -> 100))` | Uses an immutable map to set the transactional map state | ```def set(newValueMap: Map[K, V]): Txn[Unit]``` | Performance-wise it is better to set individual keys instead of setting the entire map in this manner.

This operation will create/delete key-values as needed to update the state of the map. | +| `txnVarMap.setF(Async[F].pure(Map("David" -> 100)))` | Uses an immutable map (returned in an abstracted effect wrapped in `F`) to set the transactional map state | ```def setF[F[_]: Async](newValueMap: F[Map[K, V]]): Txn[Unit]``` | Need to ensure `F[V]` does not encapsulate side-effects | +| `txnVarMap.set("David", 100)` | Upserts the key-value into the transactional map | ```def set(key: K, newValue: V): Txn[Unit]``` | Will create the key-value in the transactional map, if the key was not present | +| `txnVarMap.setF("David", Async[F].pure(100))` | Upserts the key-value into the transactional map with the value being the result of an abstracted effect wrapped in `F` | ```def setF[F[_]: Async](key: K, newValue: F[V]): Txn[Unit]``` | Will create the key-value in the transactional map, if the key was not present

Need to ensure `F[V]` does not encapsulate side-effects | +| `txnVar.modify(_ + 5)` | Modifies the value of a transactional variable | ```def modify(f: V => V): Txn[Unit]``` | | +| `txnVar.modifyF(v => Async[F].delay(v + 5))` | Modifies the value of a transactional variable via an abstract effect wrapped in `F` | ```def modifyF[F[_]: Async](f: V => F[V]): Txn[Unit]``` | Need to ensure `F[V]` does not encapsulate side-effects | +| `txnVarMap.modify("David", _ + 20)` | Modifies the value in a transactional map for a given key | ```def modify(key: K, f: V => V): Txn[Unit]``` | Will throw an error if the `key` is not present in the map (or has been deleted in the current transaction) | +| `txnVarMap.modifyF("David", v => Async[F].delay(v + 20))` | Modifies the value in a transactional map for a given key via an abstract effect wrapped in `F` | ```def modifyF[F[_]: Async](key: K, f: V => F[V]): Txn[Unit]``` | Will throw an error if the `key` is not present in the map (or has been deleted in the current transaction)

Need to ensure `F[V]` does not encapsulate side-effects | +| `txnVarMap.modify(_.map(i => i._1 -> i._2*2))` | Modifies all the values in the map | ```def modify(f: Map[K, V] => Map[K, V]): Txn[Unit]``` | Transform can create/delete entries.

Again, for performance it is better to work with individual key-value pairs instead of manipulating map views | +| `txnVarMap.modifyF(m => Async[F].delay(m.map(i => i._1 -> i._2*2)))` | Modifies all the values in the map via an abstract effect wrapped in `F` | ```def modifyF[F[_]: Async](f: Map[K, V] => F[Map[K, V]]): Txn[Unit]``` | Transform can create/delete entries.

Again, for performance it is better to work with individual key-value pairs instead of manipulating map views

Need to ensure `F[V]` does not encapsulate side-effects | +| `txnVarMap.remove("David")` | Removes a key-value from the transactional map | ```def remove(key: K): Txn[Unit]``` | Will throw an error if the key doesn't actually exist in the map (to be consistent with `get` behaviour) | +| `pure(10)` | Lifts a value into a transactional monad | ```def pure[V](value: V): Txn[V]``` | | +| `delay(10+2)` | Lifts a computation into a transactional monad (by-name value) | ```def delay[V](value: => V): Txn[V]``` | Argument will be evaluated every time a transaction is attempted. It is not advised to use with side effects. | +| `abort(new RuntimeException("foo"))` | Aborts the current transaction | ```def abort(ex: Throwable): Txn[Unit]``` | Variables/Maps changes in the transaction will not be changed if the transaction is aborted | +| `txn.handleErrorWith(_ => pure("bar"))` | Absorbs an error/abort and remaps to another transaction (of the same wrapped type) | ```def handleErrorWith(f: Throwable => Txn[V]): Txn[V]``` | | +| `waitFor(value > 10)` | Semantically blocks a transaction until a condition is met | ```def waitFor(predicate: => Boolean): Txn[Unit]``` | Blocking is only semantic (i.e. not locking up a thread while waiting)

This is implemented via retries that are initiated via variable/map updates. One can specify the `retryMaxWait` to facilitate backstop polling for these retries, but this is _not_ recommended (i.e. indicates side-effects are impacting predicate) | ### Example Note in the below that the [better-monadic-for](https://github.com/oleg-py/better-monadic-for) compiler plugin is used to expose the STM runtime as an implicit in the monadic computation. This is avoids the use of `unsafeRunSync` to expose the runtime instance, while not requiring the runtime to be explicitly passed to the sub-programs. diff --git a/build.sbt b/build.sbt index c6fdb82..8daa0a2 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,4 @@ -ThisBuild / baseVersion := "0.6.1" +ThisBuild / baseVersion := "0.7.0" ThisBuild / organization := "ai.entrolution" ThisBuild / organizationName := "Greg von Nessi" diff --git a/src/main/scala/bengal/stm/STM.scala b/src/main/scala/bengal/stm/STM.scala index 5479f41..8cc8c94 100644 --- a/src/main/scala/bengal/stm/STM.scala +++ b/src/main/scala/bengal/stm/STM.scala @@ -20,18 +20,21 @@ package bengal.stm import bengal.stm.api.internal.TxnApiContext import bengal.stm.model._ import bengal.stm.model.runtime._ -import bengal.stm.runtime.TxnRuntimeContext +import bengal.stm.runtime.{TxnCompilerContext, TxnLogContext, TxnRuntimeContext} import cats.effect.Ref -import cats.effect.implicits._ +import cats.effect.implicits.genSpawnOps import cats.effect.kernel.{Async, Deferred} import cats.effect.std.Semaphore import cats.implicits._ import scala.concurrent.duration.{FiniteDuration, NANOSECONDS} -trait STM[F[_]] - extends TxnRuntimeContext[F] +abstract class STM[F[_]: Async] + extends AsyncImplicits[F] + with TxnRuntimeContext[F] + with TxnCompilerContext[F] + with TxnLogContext[F] with TxnApiContext[F] with TxnAdtContext[F] { @@ -123,12 +126,12 @@ object STM { } override def allocateTxnVar[V](value: V): F[TxnVar[F, V]] = - TxnVar.of(value)(this, Async[F]) + TxnVar.of(value)(this, this.asyncF) override def allocateTxnVarMap[K, V]( valueMap: Map[K, V] ): F[TxnVarMap[F, K, V]] = - TxnVarMap.of(valueMap)(this, Async[F]) + TxnVarMap.of(valueMap)(this, this.asyncF) override private[stm] def commitTxn[V](txn: Txn[V]): F[V] = txnRuntime.commit(txn) diff --git a/src/main/scala/bengal/stm/api/internal/TxnApiContext.scala b/src/main/scala/bengal/stm/api/internal/TxnApiContext.scala index e187cdc..d8cc98c 100644 --- a/src/main/scala/bengal/stm/api/internal/TxnApiContext.scala +++ b/src/main/scala/bengal/stm/api/internal/TxnApiContext.scala @@ -19,13 +19,15 @@ package bengal.stm.api.internal import bengal.stm.model.TxnErratum._ import bengal.stm.model._ +import bengal.stm.runtime.TxnRuntimeContext +import cats.effect.kernel.Async import cats.free.Free import scala.util.{Failure, Success, Try} private[stm] trait TxnApiContext[F[_]] { - this: TxnAdtContext[F] => + this: AsyncImplicits[F] with TxnRuntimeContext[F] with TxnAdtContext[F] => private def liftSuccess[V](txnAdt: TxnAdt[V]): Txn[V] = Free.liftF[TxnOrErr, V](Right(txnAdt)) @@ -36,8 +38,11 @@ private[stm] trait TxnApiContext[F[_]] { val unit: Txn[Unit] = liftSuccess(TxnUnit) + def fromF[V](spec: F[V]): Txn[V] = + liftSuccess(TxnDelay(spec)) + def delay[V](thunk: => V): Txn[V] = - liftSuccess(TxnDelay(() => thunk)) + liftSuccess(TxnDelay(Async[F].delay(thunk))) def pure[V](value: V): Txn[V] = liftSuccess(TxnPure(value)) @@ -58,7 +63,12 @@ private[stm] trait TxnApiContext[F[_]] { private[stm] def handleErrorWithInternal[V](fa: => Txn[V])( f: Throwable => Txn[V] ): Txn[V] = - liftSuccess(TxnHandleError(() => fa, f)) + liftSuccess(TxnHandleError(Async[F].delay(fa), ex => Async[F].delay(f(ex)))) + + private[stm] def handleErrorWithInternalF[V](fa: => Txn[V])( + f: Throwable => F[Txn[V]] + ): Txn[V] = + liftSuccess(TxnHandleError(Async[F].delay(fa), f)) private[stm] def getTxnVar[V](txnVar: TxnVar[F, V]): Txn[V] = liftSuccess(TxnGetVar(txnVar)) @@ -67,7 +77,13 @@ private[stm] trait TxnApiContext[F[_]] { newValue: => V, txnVar: TxnVar[F, V] ): Txn[Unit] = - liftSuccess(TxnSetVar(() => newValue, txnVar)) + liftSuccess(TxnSetVar(Async[F].delay(newValue), txnVar)) + + private[stm] def setTxnVarF[V]( + newValue: F[V], + txnVar: TxnVar[F, V] + ): Txn[Unit] = + liftSuccess(TxnSetVar(newValue, txnVar)) private[stm] def modifyTxnVar[V](f: V => V, txnVar: TxnVar[F, V]): Txn[Unit] = for { @@ -75,6 +91,15 @@ private[stm] trait TxnApiContext[F[_]] { _ <- setTxnVar(f(value), txnVar) } yield () + private[stm] def modifyTxnVarF[V]( + f: V => F[V], + txnVar: TxnVar[F, V] + ): Txn[Unit] = + for { + value <- getTxnVar(txnVar) + _ <- setTxnVarF(f(value), txnVar) + } yield () + private[stm] def getTxnVarMap[K, V]( txnVarMap: TxnVarMap[F, K, V] ): Txn[Map[K, V]] = @@ -84,7 +109,13 @@ private[stm] trait TxnApiContext[F[_]] { newValueMap: => Map[K, V], txnVarMap: TxnVarMap[F, K, V] ): Txn[Unit] = - liftSuccess(TxnSetVarMap(() => newValueMap, txnVarMap)) + liftSuccess(TxnSetVarMap(Async[F].delay(newValueMap), txnVarMap)) + + private[stm] def setTxnVarMapF[K, V]( + newValueMap: F[Map[K, V]], + txnVarMap: TxnVarMap[F, K, V] + ): Txn[Unit] = + liftSuccess(TxnSetVarMap(newValueMap, txnVarMap)) private[stm] def modifyTxnVarMap[K, V]( f: Map[K, V] => Map[K, V], @@ -95,29 +126,66 @@ private[stm] trait TxnApiContext[F[_]] { _ <- setTxnVarMap(f(value), txnVarMap) } yield () + private[stm] def modifyTxnVarMapF[K, V]( + f: Map[K, V] => F[Map[K, V]], + txnVarMap: TxnVarMap[F, K, V] + ): Txn[Unit] = + for { + value <- getTxnVarMap(txnVarMap) + _ <- setTxnVarMapF(f(value), txnVarMap) + } yield () + private[stm] def getTxnVarMapValue[K, V]( key: => K, txnVarMap: TxnVarMap[F, K, V] ): Txn[Option[V]] = - liftSuccess(TxnGetVarMapValue(() => key, txnVarMap)) + liftSuccess(TxnGetVarMapValue(Async[F].delay(key), txnVarMap)) private[stm] def setTxnVarMapValue[K, V]( key: => K, newValue: => V, txnVarMap: TxnVarMap[F, K, V] ): Txn[Unit] = - liftSuccess(TxnSetVarMapValue(() => key, () => newValue, txnVarMap)) + liftSuccess( + TxnSetVarMapValue(Async[F].delay(key), + Async[F].delay(newValue), + txnVarMap + ) + ) + + private[stm] def setTxnVarMapValueF[K, V]( + key: => K, + newValue: F[V], + txnVarMap: TxnVarMap[F, K, V] + ): Txn[Unit] = + liftSuccess( + TxnSetVarMapValue(Async[F].delay(key), newValue, txnVarMap) + ) private[stm] def modifyTxnVarMapValue[K, V]( key: => K, f: V => V, txnVarMap: TxnVarMap[F, K, V] ): Txn[Unit] = - liftSuccess(TxnModifyVarMapValue(() => key, f, txnVarMap)) + liftSuccess( + TxnModifyVarMapValue(Async[F].delay(key), + (v: V) => Async[F].delay(f(v)), + txnVarMap + ) + ) + + private[stm] def modifyTxnVarMapValueF[K, V]( + key: => K, + f: V => F[V], + txnVarMap: TxnVarMap[F, K, V] + ): Txn[Unit] = + liftSuccess( + TxnModifyVarMapValue(Async[F].delay(key), f, txnVarMap) + ) private[stm] def removeTxnVarMapValue[K, V]( key: => K, txnVarMap: TxnVarMap[F, K, V] ): Txn[Unit] = - liftSuccess(TxnDeleteVarMapValue(() => key, txnVarMap)) + liftSuccess(TxnDeleteVarMapValue(Async[F].delay(key), txnVarMap)) } diff --git a/src/main/scala/bengal/stm/model/AsyncImplicits.scala b/src/main/scala/bengal/stm/model/AsyncImplicits.scala new file mode 100644 index 0000000..dfaac31 --- /dev/null +++ b/src/main/scala/bengal/stm/model/AsyncImplicits.scala @@ -0,0 +1,24 @@ +/* + * Copyright 2020-2022 Greg von Nessi + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.entrolution +package bengal.stm.model + +import cats.effect.kernel.Async + +private[stm] abstract class AsyncImplicits[F[_]]( + protected implicit val asyncF: Async[F] +) diff --git a/src/main/scala/bengal/stm/model/TxnAdtContext.scala b/src/main/scala/bengal/stm/model/TxnAdtContext.scala index 08158a9..8ff4801 100644 --- a/src/main/scala/bengal/stm/model/TxnAdtContext.scala +++ b/src/main/scala/bengal/stm/model/TxnAdtContext.scala @@ -21,14 +21,14 @@ private[stm] trait TxnAdtContext[F[_]] { private[stm] case object TxnUnit extends TxnAdt[Unit] - private[stm] case class TxnDelay[V](thunk: () => V) extends TxnAdt[V] + private[stm] case class TxnDelay[V](thunk: F[V]) extends TxnAdt[V] private[stm] case class TxnPure[V](value: V) extends TxnAdt[V] private[stm] case class TxnGetVar[V](txnVar: TxnVar[F, V]) extends TxnAdt[V] private[stm] case class TxnSetVar[V]( - newValue: () => V, + newValue: F[V], txnVar: TxnVar[F, V] ) extends TxnAdt[Unit] @@ -36,34 +36,34 @@ private[stm] trait TxnAdtContext[F[_]] { extends TxnAdt[Map[K, V]] private[stm] case class TxnGetVarMapValue[K, V]( - key: () => K, + key: F[K], txnVarMap: TxnVarMap[F, K, V] ) extends TxnAdt[Option[V]] private[stm] case class TxnSetVarMap[K, V]( - newMap: () => Map[K, V], + newMap: F[Map[K, V]], txnVarMap: TxnVarMap[F, K, V] ) extends TxnAdt[Unit] private[stm] case class TxnSetVarMapValue[K, V]( - key: () => K, - newValue: () => V, + key: F[K], + newValue: F[V], txnVarMap: TxnVarMap[F, K, V] ) extends TxnAdt[Unit] private[stm] case class TxnModifyVarMapValue[K, V]( - key: () => K, - f: V => V, + key: F[K], + f: V => F[V], txnVarMap: TxnVarMap[F, K, V] ) extends TxnAdt[Unit] private[stm] case class TxnDeleteVarMapValue[K, V]( - key: () => K, + key: F[K], txnVarMap: TxnVarMap[F, K, V] ) extends TxnAdt[Unit] private[stm] case class TxnHandleError[V]( - fa: () => Txn[V], - f: Throwable => Txn[V] + fa: F[Txn[V]], + f: Throwable => F[Txn[V]] ) extends TxnAdt[V] } diff --git a/src/main/scala/bengal/stm/model/TxnErratum.scala b/src/main/scala/bengal/stm/model/TxnErratum.scala index 35fedc9..37342b8 100644 --- a/src/main/scala/bengal/stm/model/TxnErratum.scala +++ b/src/main/scala/bengal/stm/model/TxnErratum.scala @@ -20,7 +20,6 @@ package bengal.stm.model private[stm] sealed trait TxnErratum object TxnErratum { - private[stm] case object NoErratum extends TxnErratum private[stm] case object TxnRetry extends TxnErratum diff --git a/src/main/scala/bengal/stm/model/TxnVar.scala b/src/main/scala/bengal/stm/model/TxnVar.scala index 82caaba..4404d69 100644 --- a/src/main/scala/bengal/stm/model/TxnVar.scala +++ b/src/main/scala/bengal/stm/model/TxnVar.scala @@ -20,7 +20,6 @@ package bengal.stm.model import bengal.stm.STM import bengal.stm.model.runtime._ -import cats.effect.implicits._ import cats.effect.kernel.Async import cats.effect.std.Semaphore import cats.effect.{Deferred, Ref} @@ -33,10 +32,10 @@ case class TxnVar[F[_]: Async, T]( private[stm] val txnRetrySignals: TxnSignals[F] ) extends TxnStateEntity[F, T] { - private def completeRetrySignals: F[Unit] = + private[stm] def completeRetrySignals: F[Unit] = for { signals <- txnRetrySignals.getAndSet(Set()) - _ <- signals.toList.parTraverse(_.complete(())) + _ <- signals.toList.traverse(_.complete(())) } yield () private[stm] lazy val get: F[T] = diff --git a/src/main/scala/bengal/stm/model/TxnVarMap.scala b/src/main/scala/bengal/stm/model/TxnVarMap.scala index 5ab3cac..e79059e 100644 --- a/src/main/scala/bengal/stm/model/TxnVarMap.scala +++ b/src/main/scala/bengal/stm/model/TxnVarMap.scala @@ -20,7 +20,6 @@ package bengal.stm.model import bengal.stm.STM import bengal.stm.model.runtime._ -import cats.effect.implicits._ import cats.effect.kernel.Async import cats.effect.std.Semaphore import cats.effect.{Deferred, Ref} @@ -38,6 +37,12 @@ case class TxnVarMap[F[_]: STM: Async, K, V]( private[stm] val txnRetrySignals: TxnSignals[F] ) extends TxnStateEntity[F, VarIndex[F, K, V]] { + private def completeRetrySignals: F[Unit] = + for { + signals <- txnRetrySignals.getAndSet(Set()) + _ <- signals.toList.traverse(_.complete(())) + } yield () + private def withLock[A](semaphore: Semaphore[F])( fa: F[A] ): F[A] = @@ -46,7 +51,7 @@ case class TxnVarMap[F[_]: STM: Async, K, V]( private[stm] lazy val get: F[Map[K, V]] = for { txnVarMap <- value.get - valueMap <- txnVarMap.toList.parTraverse { kv => + valueMap <- txnVarMap.toList.traverse { kv => kv._2.get.map(v => kv._1 -> v) } } yield valueMap.toMap @@ -91,7 +96,7 @@ case class TxnVarMap[F[_]: STM: Async, K, V]( keySet: Set[K] ): F[Set[TxnVarId]] = for { - ids <- keySet.toList.parTraverse(getId) + ids <- keySet.toList.traverse(getId) } yield ids.flatten.toSet // Only called when key is known to not exist @@ -101,6 +106,7 @@ case class TxnVarMap[F[_]: STM: Async, K, V]( _ <- withLock(internalStructureLock)( value.update(_ += (newKey -> newTxnVar)) ) + _ <- completeRetrySignals } yield () private[stm] def addOrUpdate(key: K, newValue: V): F[Unit] = @@ -108,7 +114,9 @@ case class TxnVarMap[F[_]: STM: Async, K, V]( txnVarMap <- value.get _ <- txnVarMap.get(key) match { case Some(tVar) => - withLock(internalStructureLock)(tVar.set(newValue)) + withLock(internalStructureLock)( + tVar.set(newValue) + ) >> completeRetrySignals case None => add(key, newValue) } @@ -118,14 +126,18 @@ case class TxnVarMap[F[_]: STM: Async, K, V]( for { txnVarMap <- value.get _ <- txnVarMap.get(key) match { - case Some(_) => - withLock(internalStructureLock)(value.update(_ -= key)) + case Some(txnVar) => + for { + _ <- withLock(internalStructureLock)(value.update(_ -= key)) + _ <- txnVar.completeRetrySignals + _ <- completeRetrySignals + } yield () case None => Async[F].unit } } yield () - override private[stm] def registerRetry( + private[stm] override def registerRetry( signal: Deferred[F, Unit] ): F[Unit] = withLock(internalSignalLock)(txnRetrySignals.update(_ + signal)) diff --git a/src/main/scala/bengal/stm/model/runtime/IdClosure.scala b/src/main/scala/bengal/stm/model/runtime/IdClosure.scala index b3824c2..32fb56c 100644 --- a/src/main/scala/bengal/stm/model/runtime/IdClosure.scala +++ b/src/main/scala/bengal/stm/model/runtime/IdClosure.scala @@ -17,8 +17,6 @@ package ai.entrolution package bengal.stm.model.runtime -import scala.collection.concurrent.{TrieMap, Map => ConcurrentMap} - private[stm] case class IdClosure( readIds: Set[TxnVarRuntimeId], updatedIds: Set[TxnVarRuntimeId] @@ -47,67 +45,3 @@ private[stm] case class IdClosure( private[stm] object IdClosure { private[stm] val empty: IdClosure = IdClosure(Set(), Set()) } - -private[stm] case class IdClosureTallies( - private val readIdTallies: ConcurrentMap[TxnVarRuntimeId, Int], - private val updatedIdTallies: ConcurrentMap[TxnVarRuntimeId, Int] -) { - - private def addReadId(id: TxnVarRuntimeId): Unit = - readIdTallies += (id -> (readIdTallies.getOrElse(id, 0) + 1)) - - private def removeReadId(id: TxnVarRuntimeId): Unit = { - val newValue: Int = readIdTallies.getOrElse(id, 0) - 1 - if (newValue < 1) { - readIdTallies -= id - } else { - readIdTallies += (id -> newValue) - } - } - - private def addUpdateId(id: TxnVarRuntimeId): Unit = - updatedIdTallies += (id -> (updatedIdTallies.getOrElse(id, 0) + 1)) - - private def removeUpdateId(id: TxnVarRuntimeId): Unit = { - val newValue: Int = updatedIdTallies.getOrElse(id, 0) - 1 - if (newValue < 1) { - updatedIdTallies -= id - } else { - updatedIdTallies += (id -> newValue) - } - } - - private def addReadIds(ids: Set[TxnVarRuntimeId]): Unit = - ids.foreach(addReadId) - - private def removeReadIds(ids: Set[TxnVarRuntimeId]): Unit = - ids.foreach(removeReadId) - - private def addUpdateIds(ids: Set[TxnVarRuntimeId]): Unit = - ids.foreach(addUpdateId) - - private def removeUpdateIds(ids: Set[TxnVarRuntimeId]): Unit = - ids.foreach(removeUpdateId) - - private[stm] def addIdClosure(idClosure: IdClosure): Unit = { - addReadIds(idClosure.readIds) - addUpdateIds(idClosure.updatedIds) - } - - private[stm] def removeIdClosure(idClosure: IdClosure): Unit = { - removeReadIds(idClosure.readIds) - removeUpdateIds(idClosure.updatedIds) - } - - private[stm] def getIdClosure: IdClosure = - IdClosure( - readIds = readIdTallies.keySet.toSet, - updatedIds = updatedIdTallies.keySet.toSet - ) -} - -private[stm] object IdClosureTallies { - - private[stm] def empty: IdClosureTallies = - IdClosureTallies(TrieMap.empty, TrieMap.empty) -} diff --git a/src/main/scala/bengal/stm/model/runtime/IdClosureTallies.scala b/src/main/scala/bengal/stm/model/runtime/IdClosureTallies.scala new file mode 100644 index 0000000..9c79c6a --- /dev/null +++ b/src/main/scala/bengal/stm/model/runtime/IdClosureTallies.scala @@ -0,0 +1,84 @@ +/* + * Copyright 2020-2022 Greg von Nessi + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ai.entrolution +package bengal.stm.model.runtime + +import scala.collection.concurrent.{TrieMap, Map => ConcurrentMap} + +private[stm] case class IdClosureTallies( + private val readIdTallies: ConcurrentMap[TxnVarRuntimeId, Int], + private val updatedIdTallies: ConcurrentMap[TxnVarRuntimeId, Int] +) { + + private def addReadId(id: TxnVarRuntimeId): Unit = + readIdTallies += (id -> (readIdTallies.getOrElse(id, 0) + 1)) + + private def removeReadId(id: TxnVarRuntimeId): Unit = { + val newValue: Int = readIdTallies.getOrElse(id, 0) - 1 + if (newValue < 1) { + readIdTallies -= id + } else { + readIdTallies += (id -> newValue) + } + } + + private def addUpdateId(id: TxnVarRuntimeId): Unit = + updatedIdTallies += (id -> (updatedIdTallies.getOrElse(id, 0) + 1)) + + private def removeUpdateId(id: TxnVarRuntimeId): Unit = { + val newValue: Int = updatedIdTallies.getOrElse(id, 0) - 1 + if (newValue < 1) { + updatedIdTallies -= id + } else { + updatedIdTallies += (id -> newValue) + } + } + + private def addReadIds(ids: Set[TxnVarRuntimeId]): Unit = + ids.foreach(addReadId) + + private def removeReadIds(ids: Set[TxnVarRuntimeId]): Unit = + ids.foreach(removeReadId) + + private def addUpdateIds(ids: Set[TxnVarRuntimeId]): Unit = + ids.foreach(addUpdateId) + + private def removeUpdateIds(ids: Set[TxnVarRuntimeId]): Unit = + ids.foreach(removeUpdateId) + + private[stm] def addIdClosure(idClosure: IdClosure): Unit = { + addReadIds(idClosure.readIds) + addUpdateIds(idClosure.updatedIds) + } + + private[stm] def removeIdClosure(idClosure: IdClosure): Unit = { + removeReadIds(idClosure.readIds) + removeUpdateIds(idClosure.updatedIds) + } + + private[stm] lazy val getIdClosure: IdClosure = + IdClosure( + readIds = readIdTallies.keySet.toSet, + updatedIds = updatedIdTallies.keySet.toSet + ) +} + +private[stm] object IdClosureTallies { + + private[stm] val empty: IdClosureTallies = + IdClosureTallies(TrieMap.empty, TrieMap.empty) +} diff --git a/src/main/scala/bengal/stm/runtime/TxnCompilerContext.scala b/src/main/scala/bengal/stm/runtime/TxnCompilerContext.scala index 961a67f..d7af64d 100644 --- a/src/main/scala/bengal/stm/runtime/TxnCompilerContext.scala +++ b/src/main/scala/bengal/stm/runtime/TxnCompilerContext.scala @@ -26,11 +26,8 @@ import cats.data.StateT import cats.effect.kernel.Async import cats.syntax.all._ -import scala.util.{Failure, Success, Try} - -private[stm] abstract class TxnCompilerContext[F[_]: Async] - extends TxnLogContext[F] { - this: TxnAdtContext[F] => +private[stm] trait TxnCompilerContext[F[_]] { + this: AsyncImplicits[F] with TxnLogContext[F] with TxnAdtContext[F] => private[stm] type IdClosureStore[T] = StateT[F, IdClosure, T] private[stm] type TxnLogStore[T] = StateT[F, TxnLog, T] @@ -53,13 +50,10 @@ private[stm] abstract class TxnCompilerContext[F[_]: Async] noOp[IdClosure].map(_.asInstanceOf[V]) case TxnDelay(thunk) => StateT[F, IdClosure, V] { s => - Async[F].delay { - Try(thunk()) match { - case Success(materializedValue) => - (s, materializedValue) - case _ => - throw StaticAnalysisShortCircuitException(s) - } + thunk.map { materializedValue => + (s, materializedValue) + }.handleErrorWith { _ => + Async[F].raiseError(StaticAnalysisShortCircuitException(s)) } } case TxnPure(value) => @@ -79,32 +73,29 @@ private[stm] abstract class TxnCompilerContext[F[_]: Async] } case adt: TxnGetVarMapValue[_, _] => StateT[F, IdClosure, V] { s => - Try(adt.key()) match { - case Success(materializedKey) => - for { - oTxnVar <- - adt.txnVarMap.getTxnVar(materializedKey) - value <- - oTxnVar - .map(_.get.map(Some(_))) - .getOrElse(Async[F].pure(None)) - oARId <- - adt.txnVarMap.getRuntimeActualisedId( - materializedKey - ) - eRId = - adt.txnVarMap.getRuntimeExistentialId( - materializedKey - ) - } yield oARId - .map(id => - (s.addReadId(id).addReadId(eRId), - value.asInstanceOf[V] - ) + adt.key.flatMap { materializedKey => + for { + oTxnVar <- + adt.txnVarMap.getTxnVar(materializedKey) + value <- + oTxnVar + .map(_.get.map(Some(_))) + .getOrElse(Async[F].pure(None)) + oARId <- + adt.txnVarMap.getRuntimeActualisedId( + materializedKey + ) + eRId = + adt.txnVarMap.getRuntimeExistentialId( + materializedKey ) - .getOrElse((s.addReadId(eRId), value.asInstanceOf[V])) - case _ => - throw StaticAnalysisShortCircuitException(s) + } yield oARId + .map(id => + (s.addReadId(id).addReadId(eRId), value.asInstanceOf[V]) + ) + .getOrElse((s.addReadId(eRId), value.asInstanceOf[V])) + }.handleErrorWith { _ => + Async[F].raiseError(StaticAnalysisShortCircuitException(s)) } } case adt: TxnSetVar[_] => @@ -121,74 +112,73 @@ private[stm] abstract class TxnCompilerContext[F[_]: Async] } case adt: TxnSetVarMapValue[_, _] => StateT[F, IdClosure, Unit] { s => - Try(adt.key()) match { - case Success(materializedKey) => - for { - oARId <- - adt.txnVarMap.getRuntimeActualisedId( - materializedKey - ) - eRId = - adt.txnVarMap.getRuntimeExistentialId( - materializedKey - ) - } yield oARId - .map(id => (s.addWriteId(id).addWriteId(eRId), ())) - .getOrElse((s.addWriteId(eRId), ())) - case _ => - Async[F].delay((s, ())) + adt.key.flatMap { materializedKey => + for { + oARId <- + adt.txnVarMap.getRuntimeActualisedId( + materializedKey + ) + eRId = + adt.txnVarMap.getRuntimeExistentialId( + materializedKey + ) + } yield oARId + .map(id => (s.addWriteId(id).addWriteId(eRId), ())) + .getOrElse((s.addWriteId(eRId), ())) + }.handleErrorWith { _ => + Async[F].delay((s, ())) } }.map(_.asInstanceOf[V]) case adt: TxnModifyVarMapValue[_, _] => StateT[F, IdClosure, Unit] { s => - Try(adt.key()) match { - case Success(materializedKey) => - for { - oARId <- - adt.txnVarMap.getRuntimeActualisedId( - materializedKey - ) - eRId = - adt.txnVarMap.getRuntimeExistentialId( - materializedKey - ) - } yield oARId - .map(id => (s.addWriteId(id).addWriteId(eRId), ())) - .getOrElse((s.addWriteId(eRId), ())) - case _ => - Async[F].delay((s, ())) + adt.key.flatMap { materializedKey => + for { + oARId <- + adt.txnVarMap.getRuntimeActualisedId( + materializedKey + ) + eRId = + adt.txnVarMap.getRuntimeExistentialId( + materializedKey + ) + } yield oARId + .map(id => (s.addWriteId(id).addWriteId(eRId), ())) + .getOrElse((s.addWriteId(eRId), ())) + }.handleErrorWith { _ => + Async[F].delay((s, ())) } }.map(_.asInstanceOf[V]) case adt: TxnDeleteVarMapValue[_, _] => StateT[F, IdClosure, Unit] { s => - Try(adt.key()) match { - case Success(materializedKey) => - for { - oARId <- - adt.txnVarMap.getRuntimeActualisedId( - materializedKey - ) - eRId = - adt.txnVarMap.getRuntimeExistentialId( - materializedKey - ) - } yield oARId - .map(id => (s.addWriteId(id).addWriteId(eRId), ())) - .getOrElse((s.addWriteId(eRId), ())) - case _ => - Async[F].delay((s, ())) + adt.key.flatMap { materializedKey => + for { + oARId <- + adt.txnVarMap.getRuntimeActualisedId( + materializedKey + ) + eRId = + adt.txnVarMap.getRuntimeExistentialId( + materializedKey + ) + } yield oARId + .map(id => (s.addWriteId(id).addWriteId(eRId), ())) + .getOrElse((s.addWriteId(eRId), ())) + }.handleErrorWith { _ => + Async[F].delay((s, ())) } }.map(_.asInstanceOf[V]) case adt: TxnHandleError[_] => StateT[F, IdClosure, V] { s => - Try(adt.fa().map(_.asInstanceOf[V])) match { - case Success(materializedF) => + adt.fa + .map(_.map(_.asInstanceOf[V])) + .flatMap { materializedF => materializedF .foldMap(staticAnalysisCompiler) .run(s) - case _ => + } + .handleErrorWith { _ => Async[F].delay((s, ().asInstanceOf[V])) - } + } } case _ => noOp[IdClosure].map(_.asInstanceOf[V]) @@ -261,23 +251,25 @@ private[stm] abstract class TxnCompilerContext[F[_]: Async] } case adt: TxnHandleError[_] => StateT[F, TxnLog, V] { s => - Try(adt.fa()) match { - case Success(materializedF) => - for { - originalResult <- - materializedF.foldMap(txnLogCompiler).run(s) - finalResult <- originalResult._1 match { - case TxnLogError(ex) => - adt - .f(ex) - .foldMap(txnLogCompiler) + (for { + materializedF <- adt.fa + originalResult <- + materializedF.foldMap(txnLogCompiler).run(s) + finalResult <- originalResult._1 match { + case TxnLogError(ex) => + adt + .f(ex) + .flatMap { + _.foldMap(txnLogCompiler) .run(s) - case _ => - Async[F].pure(originalResult) - } - } yield finalResult - case Failure(exception) => - s.raiseError(exception).map((_, ().asInstanceOf[V])) + } + case _ => + Async[F].pure(originalResult) + } + } yield (finalResult._1, + finalResult._2.asInstanceOf[V] + )).handleErrorWith { ex => + s.raiseError(ex).map((_, ().asInstanceOf[V])) } } case _ => diff --git a/src/main/scala/bengal/stm/runtime/TxnLogContext.scala b/src/main/scala/bengal/stm/runtime/TxnLogContext.scala index 2c1f457..8d8a01e 100644 --- a/src/main/scala/bengal/stm/runtime/TxnLogContext.scala +++ b/src/main/scala/bengal/stm/runtime/TxnLogContext.scala @@ -21,15 +21,14 @@ import bengal.stm.model._ import bengal.stm.model.runtime._ import cats.effect.Deferred -import cats.effect.implicits._ import cats.effect.kernel.{Async, Resource} import cats.effect.std.Semaphore import cats.syntax.all._ import scala.annotation.nowarn -import scala.util.{Failure, Success, Try} -private[stm] abstract class TxnLogContext[F[_]: Async] { +private[stm] trait TxnLogContext[F[_]] { + this: AsyncImplicits[F] => private[stm] sealed trait TxnLogEntry[V] { private[stm] def get: V @@ -357,7 +356,7 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { private[stm] def getVar[V](txnVar: TxnVar[F, V]): F[(TxnLog, V)] @nowarn - private[stm] def delay[V](value: () => V): F[(TxnLog, V)] = + private[stm] def delay[V](value: F[V]): F[(TxnLog, V)] = Async[F].delay(self, ().asInstanceOf[V]) @nowarn @@ -366,14 +365,14 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { @nowarn private[stm] def setVar[V]( - newValue: () => V, + newValue: F[V], txnVar: TxnVar[F, V] ): F[TxnLog] = Async[F].pure(self) @nowarn private[stm] def getVarMapValue[K, V]( - key: () => K, + key: F[K], txnVarMap: TxnVarMap[F, K, V] ): F[(TxnLog, Option[V])] = Async[F].pure((self, None)) @@ -386,30 +385,30 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { @nowarn private[stm] def setVarMap[K, V]( - newMap: () => Map[K, V], + newMap: F[Map[K, V]], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = Async[F].pure(self) @nowarn private[stm] def setVarMapValue[K, V]( - key: () => K, - newValue: () => V, + key: F[K], + newValue: F[V], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = Async[F].pure(self) @nowarn private[stm] def modifyVarMapValue[K, V]( - key: () => K, - f: V => V, + key: F[K], + f: V => F[V], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = Async[F].pure(self) @nowarn private[stm] def deleteVarMapValue[K, V]( - key: () => K, + key: F[K], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = Async[F].pure(self) @@ -438,16 +437,13 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { import TxnLogValid._ override private[stm] def delay[V]( - value: () => V + value: F[V] ): F[(TxnLog, V)] = - Async[F].delay { - Try(value()) match { - case Success(v) => - (this, v) - case Failure(exception) => - (TxnLogError(exception), ().asInstanceOf[V]) + value + .map((this.asInstanceOf[TxnLog], _)) + .handleErrorWith { ex => + Async[F].pure((TxnLogError(ex), ().asInstanceOf[V])) } - } override private[stm] def pure[V](value: V): F[(TxnLog, V)] = Async[F].pure((this, value)) @@ -472,34 +468,32 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { } override private[stm] def setVar[V]( - newValue: () => V, + newValue: F[V], txnVar: TxnVar[F, V] ): F[TxnLog] = - Try(newValue()) match { - case Success(materializedValue) => - (log.get(txnVar.runtimeId) match { - case Some(entry) => - Async[F].delay( - this.copy( - log + (txnVar.runtimeId -> entry - .asInstanceOf[TxnLogEntry[V]] - .set(materializedValue)) - ) + newValue.flatMap { materializedValue => + (log.get(txnVar.runtimeId) match { + case Some(entry) => + Async[F].delay( + this.copy( + log + (txnVar.runtimeId -> entry + .asInstanceOf[TxnLogEntry[V]] + .set(materializedValue)) ) - case _ => - txnVar.get.map { v => - this.copy( - log + (txnVar.runtimeId -> TxnLogUpdateVarEntry( - v, - materializedValue, - txnVar - )) - ) - } - }).map(_.asInstanceOf[TxnLog]) - case Failure(exception) => - raiseError(exception) + ) + case _ => + txnVar.get.map { v => + this.copy( + log + (txnVar.runtimeId -> TxnLogUpdateVarEntry( + v, + materializedValue, + txnVar + )) + ) + } + }).map(_.asInstanceOf[TxnLog]) } + .handleErrorWith(raiseError) private def getVarMapValueEntry[K, V]( key: K, @@ -557,7 +551,7 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { for { oldMap <- txnVarMap.get preTxn <- - oldMap.keySet.toList.parTraverse { ks => + oldMap.keySet.toList.traverse { ks => getVarMapValueEntry(ks, txnVarMap) } } yield preTxn @@ -565,7 +559,7 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { Async[F].pure(List()) } currentEntries <- extractMap(txnVarMap, log) - reads <- currentEntries.keySet.toList.parTraverse { ks => + reads <- currentEntries.keySet.toList.traverse { ks => getVarMapValueEntry(ks, txnVarMap) } } yield (preTxnEntries ::: reads).flatten.toMap @@ -592,54 +586,48 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { } override private[stm] def getVarMapValue[K, V]( - key: () => K, + key: F[K], txnVarMap: TxnVarMap[F, K, V] ): F[(TxnLog, Option[V])] = - Try(key()) match { - case Success(materializedKey) => - for { - oTxnVar <- txnVarMap.getTxnVar(materializedKey) - result <- oTxnVar match { - case Some(txnVar) => - log.get(txnVar.runtimeId) match { - case Some(entry) => - Async[F].delay( - (this, entry.get.asInstanceOf[Option[V]]) - ) //Noop - case None => - for { - txnVal <- txnVar.get - } yield (this.copy( - log + (txnVar.runtimeId -> TxnLogReadOnlyVarMapEntry( - materializedKey, - Some(txnVal), - txnVarMap - )) - ), - Some(txnVal) - ) - } + (for { + materializedKey <- key + oTxnVar <- txnVarMap.getTxnVar(materializedKey) + result <- (oTxnVar match { + case Some(txnVar) => + log.get(txnVar.runtimeId) match { + case Some(entry) => + Async[F].delay( + (this, entry.get) + ) //Noop case None => for { - rids <- txnVarMap.getRuntimeId( - materializedKey - ) - } yield rids.flatMap(log.get) match { - case entry :: Nil => - (this, entry.get.asInstanceOf[Option[V]]) //Noop - case _ => - (TxnLogError { - new RuntimeException( - s"Tried to read non-existent key $key in transactional map" - ) - }, - None - ) - } + txnVal <- txnVar.get + } yield (this.copy( + log + (txnVar.runtimeId -> TxnLogReadOnlyVarMapEntry( + materializedKey, + Some(txnVal), + txnVarMap + )) + ), + Some(txnVal) + ) } - } yield result - case Failure(exception) => - raiseError(exception).map(log => (log, None)) + case None => + for { + rids <- txnVarMap.getRuntimeId( + materializedKey + ) + } yield rids.flatMap(log.get) match { + case entry :: Nil => + (this, entry.get) //Noop + case _ => + (this, None) + } + }).map { case (log, value) => + (log.asInstanceOf[TxnLog], value.asInstanceOf[Option[V]]) + } + } yield result).handleErrorWith { ex => + raiseError(ex).map(log => (log, None)) } private def setVarMapValueEntry[K, V]( @@ -758,245 +746,246 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { } yield result override private[stm] def setVarMap[K, V]( - newMap: () => Map[K, V], + newMap: F[Map[K, V]], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = - Try(newMap()) match { - case Success(materializedNewMap) => - val individualEntries: F[Map[TxnVarRuntimeId, TxnLogEntry[_]]] = for { - currentMap <- extractMap(txnVarMap, log) - deletions <- - (currentMap.keySet -- materializedNewMap.keySet).toList.parTraverse { - ks => - deleteVarMapValueEntry(ks, txnVarMap) - } - updates <- materializedNewMap.toList.parTraverse { kv => - setVarMapValueEntry(kv._1, kv._2, txnVarMap) - } - } yield (deletions ::: updates).flatten.toMap - - (log.get(txnVarMap.runtimeId) match { - case Some(entry) => - individualEntries.map { entries => - this.copy( - (log ++ entries) + (txnVarMap.runtimeId -> entry - .asInstanceOf[TxnLogEntry[Map[K, V]]] - .set(materializedNewMap)) - ) - } - case _ => - for { - v <- txnVarMap.get - entries <- individualEntries - } yield this.copy( - (log ++ entries) + (txnVarMap.runtimeId -> TxnLogUpdateVarMapStructureEntry( - v, - materializedNewMap, - txnVarMap - )) + newMap.flatMap { materializedNewMap => + val individualEntries: F[Map[TxnVarRuntimeId, TxnLogEntry[_]]] = for { + currentMap <- extractMap(txnVarMap, log) + deletions <- + (currentMap.keySet -- materializedNewMap.keySet).toList.traverse { + ks => + deleteVarMapValueEntry(ks, txnVarMap) + } + updates <- materializedNewMap.toList.traverse { kv => + setVarMapValueEntry(kv._1, kv._2, txnVarMap) + } + } yield (deletions ::: updates).flatten.toMap + + (log.get(txnVarMap.runtimeId) match { + case Some(entry) => + individualEntries.map { entries => + this.copy( + (log ++ entries) + (txnVarMap.runtimeId -> entry + .asInstanceOf[TxnLogEntry[Map[K, V]]] + .set(materializedNewMap)) ) - }).map(_.asInstanceOf[TxnLog]) - case Failure(exception) => - raiseError(exception) + } + case _ => + for { + v <- txnVarMap.get + entries <- individualEntries + } yield this.copy( + (log ++ entries) + (txnVarMap.runtimeId -> TxnLogUpdateVarMapStructureEntry( + v, + materializedNewMap, + txnVarMap + )) + ) + }).map(_.asInstanceOf[TxnLog]) } + .handleErrorWith(raiseError) override private[stm] def setVarMapValue[K, V]( - key: () => K, - newValue: () => V, + key: F[K], + newValue: F[V], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = - Try((key(), newValue())) match { - case Success((materializedKey, materializedNewValue)) => - txnVarMap - .getTxnVar(materializedKey) - .flatMap { - case Some(txnVar) => - log.get(txnVar.runtimeId) match { - case Some(entry) => - Async[F].delay( - this.copy( - log + (txnVar.runtimeId -> entry - .asInstanceOf[TxnLogEntry[Option[V]]] - .set(Some(materializedNewValue))) - ) - ) - case None => - txnVar.get.map { v => - if (v != materializedNewValue) { - this.copy( - log + (txnVar.runtimeId -> TxnLogUpdateVarMapEntry( - materializedKey, - Some(v), - Some(materializedNewValue), - txnVarMap - )) - ) - } else { - this - } + (for { + materializedKey <- key + materializedNewValue <- newValue + result <- txnVarMap + .getTxnVar(materializedKey) + .flatMap { + case Some(txnVar) => + log.get(txnVar.runtimeId) match { + case Some(entry) => + Async[F].delay( + this.copy( + log + (txnVar.runtimeId -> entry + .asInstanceOf[TxnLogEntry[Option[V]]] + .set(Some(materializedNewValue))) + ) + ) + case None => + txnVar.get.map { v => + if (v != materializedNewValue) { + this.copy( + log + (txnVar.runtimeId -> TxnLogUpdateVarMapEntry( + materializedKey, + Some(v), + Some(materializedNewValue), + txnVarMap + )) + ) + } else { + this + } + } + } + case None => + // The txnVar may have been set to be created in this transaction + txnVarMap.getRuntimeId(materializedKey).map { rids => + rids + .flatMap(rid => log.get(rid).map((rid, _))) match { + case (rid, entry) :: _ => + this.copy( + log + (rid -> entry + .asInstanceOf[TxnLogEntry[Option[V]]] + .set(Some(materializedNewValue))) + ) + case _ => + this.copy( + log + (rids.head -> TxnLogUpdateVarMapEntry( + materializedKey, + None, + Some(materializedNewValue), + txnVarMap + )) + ) + } + } } - } - case None => - // The txnVar may have been set to be created in this transaction - txnVarMap.getRuntimeId(materializedKey).map { rids => - rids.flatMap(rid => log.get(rid).map((rid, _))) match { - case (rid, entry) :: _ => - this.copy( - log + (rid -> entry - .asInstanceOf[TxnLogEntry[Option[V]]] - .set(Some(materializedNewValue))) - ) - case _ => - this.copy( - log + (rids.head -> TxnLogUpdateVarMapEntry( - materializedKey, - None, - Some(materializedNewValue), - txnVarMap - )) - ) - } - } - } - .map(_.asInstanceOf[TxnLog]) - case Failure(exception) => - raiseError(exception) - } + .map(_.asInstanceOf[TxnLog]) + } yield result) + .handleErrorWith(raiseError) override private[stm] def modifyVarMapValue[K, V]( - key: () => K, - f: V => V, + key: F[K], + f: V => F[V], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = - Try(key()) match { - case Success(materializedKey) => - txnVarMap - .getTxnVar(materializedKey) - .flatMap { - case Some(txnVar) => - (log.get(txnVar.runtimeId) match { - case Some(entry) => - Async[F].delay { - entry.get.asInstanceOf[Option[V]] match { - case Some(v) => - this.copy( - log + (txnVar.runtimeId -> entry - .asInstanceOf[TxnLogEntry[Option[V]]] - .set(Some(f(v)))) - ) - case None => - TxnLogError( - new RuntimeException( - s"Key $key not found for modification" - ) + key.flatMap { materializedKey => + txnVarMap + .getTxnVar(materializedKey) + .flatMap { + case Some(txnVar) => + (log.get(txnVar.runtimeId) match { + case Some(entry) => + entry.get.asInstanceOf[Option[V]] match { + case Some(v) => + f(v).map { innerResult => + this.copy( + log + (txnVar.runtimeId -> entry + .asInstanceOf[TxnLogEntry[Option[V]]] + .set(Some(innerResult))) + ) + } + case None => + Async[F].delay { + TxnLogError( + new RuntimeException( + s"Key $materializedKey not found for modification" ) + ) } - } - case None => - txnVar.get.map { v => - this.copy( - log + (txnVar.runtimeId -> TxnLogUpdateVarMapEntry( - materializedKey, - Some(v), - Some(f(v)), - txnVarMap - )) - ) - } - }).map(_.asInstanceOf[TxnLog]) - case None => - // The txnVar may have been set to be created in this transaction - txnVarMap - .getRuntimeId(materializedKey) - .map { rids => - rids.flatMap(rid => log.get(rid).map((rid, _))) match { - case (rid, entry) :: _ => - val castEntry: TxnLogEntry[Option[V]] = - entry.asInstanceOf[TxnLogEntry[Option[V]]] - castEntry.get match { - case Some(v) => + } + case None => + for { + v <- txnVar.get + innerResult <- f(v) + } yield this.copy( + log + (txnVar.runtimeId -> TxnLogUpdateVarMapEntry( + materializedKey, + Some(v), + Some(innerResult), + txnVarMap + )) + ) + }).map(_.asInstanceOf[TxnLog]) + case None => + // The txnVar may have been set to be created in this transaction + txnVarMap + .getRuntimeId(materializedKey) + .flatMap { rids => + (rids.flatMap(rid => log.get(rid).map((rid, _))) match { + case (rid, entry) :: _ => + val castEntry: TxnLogEntry[Option[V]] = + entry.asInstanceOf[TxnLogEntry[Option[V]]] + castEntry.get match { + case Some(v) => + f(v).map { innerResult => this.copy( log + (rid -> castEntry - .set(Some(f(v)))) + .set(Some(innerResult))) ) - case None => + } + case None => + Async[F].delay { TxnLogError( new RuntimeException( - s"Key $key not found for modification" + s"Key $materializedKey not found for modification" ) ) - } - case _ => + } + } + case _ => + Async[F].delay { TxnLogError( new RuntimeException( - s"Key $key not found for modification" + s"Key $materializedKey not found for modification" ) ) - } - } - .map(_.asInstanceOf[TxnLog]) - } - case Failure(exception) => - raiseError(exception) - } + } + }).map(_.asInstanceOf[TxnLog]) + } + } + }.handleErrorWith(raiseError) override private[stm] def deleteVarMapValue[K, V]( - key: () => K, + key: F[K], txnVarMap: TxnVarMap[F, K, V] ): F[TxnLog] = - Try(key()) match { - case Success(materializedKey) => - for { - oTxnVar <- txnVarMap.getTxnVar(materializedKey) - result <- oTxnVar match { - case Some(txnVar) => - log.get(txnVar.runtimeId) match { - case Some(entry) => - Async[F].delay( - this.copy( - log + (txnVar.runtimeId -> entry - .asInstanceOf[TxnLogEntry[Option[V]]] - .set(None)) - ) - ) - case None => - for { - txnVal <- txnVar.get - } yield this.copy( - log + (txnVar.runtimeId -> - TxnLogUpdateVarMapEntry(materializedKey, - Some(txnVal), - None, - txnVarMap - )) - ) - } + (for { + materializedKey <- key + oTxnVar <- txnVarMap.getTxnVar(materializedKey) + result <- oTxnVar match { + case Some(txnVar) => + log.get(txnVar.runtimeId) match { + case Some(entry) => + Async[F].delay( + this.copy( + log + (txnVar.runtimeId -> entry + .asInstanceOf[TxnLogEntry[Option[V]]] + .set(None)) + ) + ) case None => for { - rids <- txnVarMap.getRuntimeId( - materializedKey - ) - } yield rids.flatMap(rid => - log.get(rid).map((rid, _)) - ) match { - case (rid, entry) :: _ => - this.copy( - log + (rid -> entry - .asInstanceOf[TxnLogEntry[Option[V]]] - .set(None)) - ) - case _ => // Throw error to be consistent with read behaviour - TxnLogError { - new RuntimeException( - s"Tried to remove non-existent key $key in transactional map" + txnVal <- txnVar.get + } yield this.copy( + log + (txnVar.runtimeId -> + TxnLogUpdateVarMapEntry(materializedKey, + Some(txnVal), + None, + txnVarMap + )) + ) + } + case None => + for { + rids <- txnVarMap.getRuntimeId( + materializedKey ) - } + } yield rids.flatMap(rid => + log.get(rid).map((rid, _)) + ) match { + case (rid, entry) :: _ => + this.copy( + log + (rid -> entry + .asInstanceOf[TxnLogEntry[Option[V]]] + .set(None)) + ) + case _ => // Throw error to be consistent with read behaviour + TxnLogError { + new RuntimeException( + s"Tried to remove non-existent key $materializedKey in transactional map" + ) } } - } yield result - case Failure(exception) => - raiseError(exception) - } + } + } yield result.asInstanceOf[TxnLog]) + .handleErrorWith(raiseError) override private[stm] def raiseError(ex: Throwable): F[TxnLog] = Async[F].delay(TxnLogError(ex)) @@ -1029,13 +1018,13 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { } override private[stm] lazy val idClosure: F[IdClosure] = - log.values.toList.parTraverse { entry => + log.values.toList.traverse { entry => entry.idClosure }.map(_.reduce(_ mergeWith _)) override private[stm] def withLock[A](fa: F[A]): F[A] = for { - locks <- log.values.toList.parTraverse(_.lock) + locks <- log.values.toList.traverse(_.lock) result <- locks.toSet.flatten .foldLeft(Resource.eval(Async[F].unit))((i, j) => i >> j.permit) @@ -1043,7 +1032,7 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { } yield result override private[stm] lazy val commit: F[Unit] = - log.values.toList.parTraverse(_.commit).void + log.values.toList.traverse(_.commit).void override private[stm] lazy val getRetrySignal : F[Option[Deferred[F, Unit]]] = @@ -1098,8 +1087,8 @@ private[stm] abstract class TxnLogContext[F[_]: Async] { for { retrySignal <- Deferred[F, Unit] registerRetries <- - validLog.log.values.toList.parTraverse(_.getRegisterRetry) - _ <- registerRetries.parTraverse(rr => rr(retrySignal)) + validLog.log.values.toList.traverse(_.getRegisterRetry) + _ <- registerRetries.traverse(rr => rr(retrySignal)) } yield Some(retrySignal) override private[stm] lazy val scheduleRetry = diff --git a/src/main/scala/bengal/stm/runtime/TxnRuntimeContext.scala b/src/main/scala/bengal/stm/runtime/TxnRuntimeContext.scala index 4a48ff2..d3940ef 100644 --- a/src/main/scala/bengal/stm/runtime/TxnRuntimeContext.scala +++ b/src/main/scala/bengal/stm/runtime/TxnRuntimeContext.scala @@ -29,9 +29,11 @@ import cats.syntax.all._ import scala.collection.mutable.{ListBuffer, Map => MutableMap} import scala.concurrent.duration.FiniteDuration -private[stm] abstract class TxnRuntimeContext[F[_]: Async] - extends TxnCompilerContext[F] { - this: TxnAdtContext[F] => +private[stm] trait TxnRuntimeContext[F[_]] { + this: AsyncImplicits[F] + with TxnCompilerContext[F] + with TxnLogContext[F] + with TxnAdtContext[F] => private[stm] val txnIdGen: Ref[F, TxnId] private[stm] val txnVarIdGen: Ref[F, TxnVarId] diff --git a/src/main/scala/bengal/stm/syntax/all/package.scala b/src/main/scala/bengal/stm/syntax/all/package.scala index 2d4336a..d6bd55b 100644 --- a/src/main/scala/bengal/stm/syntax/all/package.scala +++ b/src/main/scala/bengal/stm/syntax/all/package.scala @@ -30,8 +30,14 @@ package object all { def set(newValue: => V): Txn[Unit] = STM[F].setTxnVar(newValue, txnVar) + def setF(newValue: F[V]): Txn[Unit] = + STM[F].setTxnVarF(newValue, txnVar) + def modify(f: V => V): Txn[Unit] = STM[F].modifyTxnVar(f, txnVar) + + def modifyF(f: V => F[V]): Txn[Unit] = + STM[F].modifyTxnVarF(f, txnVar) } implicit class TxnVarMapOps[F[_]: STM, K, V](txnVarMap: TxnVarMap[F, K, V]) { @@ -42,18 +48,30 @@ package object all { def set(newValueMap: => Map[K, V]): Txn[Unit] = STM[F].setTxnVarMap(newValueMap, txnVarMap) + def set(newValueMap: F[Map[K, V]]): Txn[Unit] = + STM[F].setTxnVarMapF(newValueMap, txnVarMap) + def modify(f: Map[K, V] => Map[K, V]): Txn[Unit] = STM[F].modifyTxnVarMap(f, txnVarMap) + def modifyF(f: Map[K, V] => F[Map[K, V]]): Txn[Unit] = + STM[F].modifyTxnVarMapF(f, txnVarMap) + def get(key: => K): Txn[Option[V]] = STM[F].getTxnVarMapValue(key, txnVarMap) def set(key: => K, newValue: => V): Txn[Unit] = STM[F].setTxnVarMapValue(key, newValue, txnVarMap) + def setF(key: => K, newValue: F[V]): Txn[Unit] = + STM[F].setTxnVarMapValueF(key, newValue, txnVarMap) + def modify(key: => K, f: V => V): Txn[Unit] = STM[F].modifyTxnVarMapValue(key, f, txnVarMap) + def modifyF(key: => K, f: V => F[V]): Txn[Unit] = + STM[F].modifyTxnVarMapValueF(key, f, txnVarMap) + def remove(key: => K): Txn[Unit] = STM[F].removeTxnVarMapValue(key, txnVarMap) } @@ -65,5 +83,8 @@ package object all { def handleErrorWith(f: Throwable => Txn[V]): Txn[V] = STM[F].handleErrorWithInternal(txn)(f) + + def handleErrorWithF(f: Throwable => F[Txn[V]]): Txn[V] = + STM[F].handleErrorWithInternalF(txn)(f) } } diff --git a/src/test/scala/model/TxnVarMapSpec.scala b/src/test/scala/model/TxnVarMapSpec.scala index cce59db..054e3e5 100644 --- a/src/test/scala/model/TxnVarMapSpec.scala +++ b/src/test/scala/model/TxnVarMapSpec.scala @@ -32,7 +32,7 @@ class TxnVarMapSpec with AsyncIOSpec with Matchers with EitherValues { - val baseMap = Map("foo" -> 42, "bar" -> 27, "baz" -> 18) + val baseMap: Map[String, Int] = Map("foo" -> 42, "bar" -> 27, "baz" -> 18) "TxnVarMap.get" - { "return the value of a transactional map" in { @@ -82,12 +82,12 @@ class TxnVarMapSpec } yield result).asserting(_ shouldBe Some(42)) } - "throw an error if key isn't present" in { + "return None if key isn't present" in { (for { implicit0(stm: STM[IO]) <- STM.runtime[IO] tVarMap <- TxnVarMap.of(baseMap) - result <- tVarMap.get("foobar").commit.attempt - } yield result).asserting(_.left.value shouldBe a[RuntimeException]) + result <- tVarMap.get("foobar").commit + } yield result).asserting(_ shouldBe None) } "return None if the key is deleted in the current transaction" in {