diff --git a/core/shared/src/main/scala/coop/Deferred.scala b/core/shared/src/main/scala/coop/Deferred.scala new file mode 100644 index 0000000..a101d2d --- /dev/null +++ b/core/shared/src/main/scala/coop/Deferred.scala @@ -0,0 +1,51 @@ +/* + * Copyright 2020 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package coop + +import cats.{Applicative, Monad} +import cats.free.FreeT +import cats.syntax.all._ + +import ThreadF._ + +final class Deferred[A] private[coop] (private[coop] val monitorId: MonitorId) { self => + def tryGet[F[_]: Applicative]: ThreadT[F, Option[A]] = + FreeT.liftF(TryGetDeferred(this, identity[Option[A]])) + + def get[F[_]: Monad]: ThreadT[F, A] = + tryGet[F].flatMap { + case Some(a) => Applicative[ThreadT[F, *]].pure(a) + case None => ThreadT.await(monitorId) >> get[F] + } + + def complete[F[_]: Monad](a: A): ThreadT[F, Unit] = + FreeT.liftF(CompleteDeferred(this, a, () => ()): ThreadF[Unit]) >> ThreadT.notify[F](monitorId) + + def apply[F[_]: Monad]: DeferredPartiallyApplied[F] = + new DeferredPartiallyApplied[F] + + class DeferredPartiallyApplied[F[_]: Monad] { + def tryGet: ThreadT[F, Option[A]] = self.tryGet + def get: ThreadT[F, A] = self.get + def complete(a: A): ThreadT[F, Unit] = self.complete(a) + } +} + +object Deferred { + def apply[F[_]: Applicative, A]: ThreadT[F, Deferred[A]] = + ThreadT.monitor[F].flatMap(id => FreeT.liftF(MkDeferred(id, identity[Deferred[A]]))) +} diff --git a/core/shared/src/main/scala/coop/Ref.scala b/core/shared/src/main/scala/coop/Ref.scala new file mode 100644 index 0000000..f5b2f6c --- /dev/null +++ b/core/shared/src/main/scala/coop/Ref.scala @@ -0,0 +1,62 @@ +/* + * Copyright 2020 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package coop + +import cats.Applicative +import cats.free.FreeT + +import ThreadF._ + +final class Ref[A] private[coop] (private[coop] val monitorId: MonitorId) { self => + def get[F[_]: Applicative]: ThreadT[F, A] = + modify(a => (a, a)) + + def set[F[_]: Applicative](a: A): ThreadT[F, Unit] = + modify(_ => (a, ())) + + def modify[F[_]: Applicative, B](f: A => (A, B)): ThreadT[F, B] = + FreeT.liftF(ModifyRef(this, f, identity[B])) + + def getAndSet[F[_]: Applicative](a: A): ThreadT[F, A] = + modify(oldA => (a, oldA)) + + def getAndUpdate[F[_]: Applicative](f: A => A): ThreadT[F, A] = + modify(a => (f(a), a)) + + def updateAndGet[F[_]: Applicative](f: A => A): ThreadT[F, A] = + modify { a => + val newA = f(a) + (newA, newA) + } + + def apply[F[_]: Applicative]: RefPartiallyApplied[F] = + new RefPartiallyApplied[F] + + class RefPartiallyApplied[F[_]: Applicative] { + val get: ThreadT[F, A] = self.get + def set(a: A): ThreadT[F, Unit] = self.set(a) + def modify[B](f: A => (A, B)): ThreadT[F, B] = self.modify(f) + def getAndSet(a: A): ThreadT[F, A] = self.getAndSet(a) + def getAndUpdate(f: A => A): ThreadT[F, A] = self.getAndUpdate(f) + def updateAndGet(f: A => A): ThreadT[F, A] = self.updateAndGet(f) + } +} + +object Ref { + def of[F[_]: Applicative, A](a: A): ThreadT[F, Ref[A]] = + ThreadT.monitor[F].flatMap(id => FreeT.liftF(MkRef(a, id, identity[Ref[A]]))) +} diff --git a/core/shared/src/main/scala/coop/ThreadF.scala b/core/shared/src/main/scala/coop/ThreadF.scala index fa827ba..4db2b2d 100644 --- a/core/shared/src/main/scala/coop/ThreadF.scala +++ b/core/shared/src/main/scala/coop/ThreadF.scala @@ -35,6 +35,13 @@ object ThreadF { case Annotate(text, results) => Annotate(text, () => f(results())) case Indent(results) => Indent(() => f(results())) case Dedent(results) => Dedent(() => f(results())) + + case MkRef(a, id, body) => MkRef(a, id, body.andThen(f)) + case ModifyRef(ref, modF, body) => ModifyRef(ref, modF, body.andThen(f)) + + case MkDeferred(id, body) => MkDeferred(id, body.andThen(f)) + case TryGetDeferred(deferred, body) => TryGetDeferred(deferred, body.andThen(f)) + case CompleteDeferred(deferred, a, body) => CompleteDeferred(deferred, a, () => f(body())) } } @@ -50,6 +57,14 @@ object ThreadF { final case class Indent[A](results: () => A) extends ThreadF[A] final case class Dedent[A](results: () => A) extends ThreadF[A] + final case class MkRef[A, B](a: A, id: MonitorId, body: Ref[A] => B) extends ThreadF[B] + final case class ModifyRef[A, B, C](ref: Ref[A], f: A => (A, B), body: B => C) extends ThreadF[C] + + final case class MkDeferred[A, B](id: MonitorId, body: Deferred[A] => B) extends ThreadF[B] + final case class TryGetDeferred[A, B](deferred: Deferred[A], body: Option[A] => B) extends ThreadF[B] + final case class CompleteDeferred[A, B](deferred: Deferred[A], a: A, body: () => B) extends ThreadF[B] + // an opaque fresh id - final class MonitorId private[coop] () + final case class MonitorId private[coop] (private[coop] val id: Int) + private object MonitorId } diff --git a/core/shared/src/main/scala/coop/ThreadT.scala b/core/shared/src/main/scala/coop/ThreadT.scala index 1c61d2a..b90b273 100644 --- a/core/shared/src/main/scala/coop/ThreadT.scala +++ b/core/shared/src/main/scala/coop/ThreadT.scala @@ -63,44 +63,70 @@ object ThreadT { case class LoopState( head: Option[() => ThreadT[M, _]], work: Queue[() => ThreadT[M, _]], - locks: Map[MonitorId, Queue[() => ThreadT[M, _]]]) + monitorCount: Int, + locks: Map[MonitorId, Queue[() => ThreadT[M, _]]], + refs: Map[MonitorId, _], + deferreds: Map[MonitorId, _] + ) - Monad[M].tailRecM(LoopState(Some(() => main), Queue.empty, Map.empty)) { ls => - val LoopState(head, work, locks) = ls + Monad[M].tailRecM(LoopState(Some(() => main), Queue.empty, 0, Map.empty, Map.empty, Map.empty)) { ls => + val LoopState(head, work, count, locks, refs, deferreds) = ls head.tupleRight(work).orElse(work.dequeueOption) match { case Some((head, tail)) => head().resume map { case Left(Fork(left, right)) => - Left(LoopState(Some(left), tail.enqueue(right), locks)) + Left(LoopState(Some(left), tail.enqueue(right), count, locks, refs, deferreds)) case Left(Cede(results)) => val tail2 = tail.enqueue(results) - Left(LoopState(None, tail2, locks)) + Left(LoopState(None, tail2, count, locks, refs, deferreds)) case Left(Done) | Right(_) => - Left(LoopState(None, tail, locks)) + Left(LoopState(None, tail, count, locks, refs, deferreds)) case Left(Monitor(f)) => - val id = new MonitorId() - Left(LoopState(Some(() => f(id)), tail, locks + (id -> Queue.empty))) + val id = new MonitorId(count) + Left(LoopState(Some(() => f(id)), tail, count + 1, locks + (id -> Queue.empty), refs, deferreds)) case Left(Await(id, results)) => - Left(LoopState(None, tail, locks.updated(id, locks(id).enqueue(results)))) + Left(LoopState(None, tail, count, locks.updated(id, locks(id).enqueue(results)), refs, deferreds)) case Left(Notify(id, results)) => // enqueueAll was added in 2.13 val tail2 = locks(id).foldLeft(tail)(_.enqueue(_)) - Left(LoopState(None, tail2.enqueue(results), locks.updated(id, Queue.empty))) + Left(LoopState(None, tail2.enqueue(results), count, locks.updated(id, Queue.empty), refs, deferreds)) case Left(Annotate(_, results)) => - Left(LoopState(Some(results), tail, locks)) + Left(LoopState(Some(results), tail, count, locks, refs, deferreds)) case Left(Indent(results)) => - Left(LoopState(Some(results), tail, locks)) + Left(LoopState(Some(results), tail, count, locks, refs, deferreds)) case Left(Dedent(results)) => - Left(LoopState(Some(results), tail, locks)) + Left(LoopState(Some(results), tail, count, locks, refs, deferreds)) + + case Left(mkref: MkRef[a, b]) => + val a = mkref.a + val ref = new Ref[a](mkref.id) + Left(LoopState(Some(() => mkref.body(ref)), tail, count, locks, refs + (mkref.id -> a), deferreds)) + + case Left(modifyRef: ModifyRef[a, b, c]) => + val a = refs(modifyRef.ref.monitorId).asInstanceOf[a] + val (newA, b) = modifyRef.f(a) + Left(LoopState(Some(() => modifyRef.body(b)), tail, count, locks, refs.updated(modifyRef.ref.monitorId, newA), deferreds)) + + case Left(mkDeferred: MkDeferred[a, b]) => + val deferred = new Deferred[a](mkDeferred.id) + Left(LoopState(Some(() => mkDeferred.body(deferred)), tail, count, locks, refs, deferreds)) + + case Left(tryGetDeferred: TryGetDeferred[a, b]) => + val optA = deferreds.get(tryGetDeferred.deferred.monitorId).map(_.asInstanceOf[a]) + Left(LoopState(Some(() => tryGetDeferred.body(optA)), tail, count, locks, refs, deferreds)) + + case Left(completeDeferred: CompleteDeferred[a, b]) => + val newA = deferreds.get(completeDeferred.deferred.monitorId).map(_.asInstanceOf[a]).getOrElse(completeDeferred.a) + Left(LoopState(Some(() => completeDeferred.body()), tail, count, locks, refs, deferreds.updated(completeDeferred.deferred.monitorId, newA))) } // if we have outstanding awaits but no active fibers, then we're deadlocked @@ -157,15 +183,23 @@ object ThreadT { def drawId(id: MonitorId): String = "0x" + id.hashCode.toHexString.toUpperCase + def drawRef(ref: Ref[_], a: Any): String = "Ref(id = " + drawId(ref.monitorId) + ") =" + a.toString + + def drawDeferred(deferred: Deferred[_]): String = "Deferred(id = " + drawId(deferred.monitorId) + ")" + case class LoopState( target: ThreadT[M, A], acc: String, indent: List[Boolean], - init: Boolean = false) + init: Boolean = false, + monitorCount: Int, + refs: Map[MonitorId, _], + deferreds: Map[MonitorId, _] + ) - def loop(target: ThreadT[M, A], acc: String, indent: List[Boolean], init: Boolean): M[String] = { - Monad[M].tailRecM(LoopState(target, acc, indent, init)) { ls => - val LoopState(target, acc0, indent, init) = ls + def loop(target: ThreadT[M, A], acc: String, indent: List[Boolean], init: Boolean, monitorCount: Int, refs: Map[MonitorId, _], deferreds: Map[MonitorId, _]): M[String] = { + Monad[M].tailRecM(LoopState(target, acc, indent, init, monitorCount, refs, deferreds)) { ls => + val LoopState(target, acc0, indent, init, count, refs, deferreds) = ls val junc0 = if (init) InverseTurnRight else Junction val trailing = if (indent != Nil) "\n" + drawIndent(indent, "") else "" @@ -187,42 +221,68 @@ object ThreadT { case Left(Fork(left, right)) => val leading = drawIndent(indent, junc + " Fork") + "\n" + drawIndent(indent, ForkStr) - loop(right(), "", true :: indent, false) map { rightStr => + loop(right(), "", true :: indent, false, count, refs, deferreds) map { rightStr => val acc2 = acc + leading + "\n" + rightStr + "\n" - LoopState(left(), acc2, indent, false).asLeft[String] + LoopState(left(), acc2, indent, false, count, refs, deferreds).asLeft[String] } case Left(Cede(results)) => val acc2 = acc + drawIndent(indent, junc + " Cede") + "\n" - LoopState(results(), acc2, indent, false).asLeft[String].pure[M] + LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M] case Left(Done) => (acc + drawIndent(indent, TurnRight + " Done" + trailing)).asRight[LoopState].pure[M] case Left(Monitor(f)) => - val id = new MonitorId - LoopState(f(id), acc, indent, init).asLeft[String].pure[M] // don't render the creation + val id = new MonitorId(count) + LoopState(f(id), acc, indent, init, count + 1, refs, deferreds).asLeft[String].pure[M] // don't render the creation case Left(Await(id, results)) => val acc2 = acc + drawIndent(indent, junc + " Await ") + drawId(id) + "\n" - LoopState(results(), acc2, indent, false).asLeft[String].pure[M] + LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M] case Left(Notify(id, results)) => val acc2 = acc + drawIndent(indent, junc + " Notify ") + drawId(id) + "\n" - LoopState(results(), acc2, indent, false).asLeft[String].pure[M] + LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M] case Left(Annotate(name, results)) => val acc2 = acc + drawIndent(indent, junc + s" $name") + "\n" - LoopState(results(), acc2, indent, false).asLeft[String].pure[M] + LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M] case Left(Indent(results)) => val acc2 = acc + drawIndent(indent, IndentStr) + "\n" - LoopState(results(), acc2, false :: indent, false).asLeft[String].pure[M] + LoopState(results(), acc2, false :: indent, false, count, refs, deferreds).asLeft[String].pure[M] case Left(Dedent(results)) => val indent2 = indent.tail val acc2 = acc + drawIndent(indent2, DedentStr) + "\n" - LoopState(results(), acc2, indent2, false).asLeft[String].pure[M] + LoopState(results(), acc2, indent2, false, count, refs, deferreds).asLeft[String].pure[M] + + case Left(mkRef: MkRef[a, b]) => + val ref = new Ref[a](mkRef.id) + val acc2 = acc + drawIndent(indent, junc + " Create ref ") + drawRef(ref, mkRef.a) + "\n" + LoopState(mkRef.body(ref), acc2, indent, init, count, refs + (ref.monitorId -> mkRef.a), deferreds).asLeft[String].pure[M] + + case Left(modifyRef: ModifyRef[a, b, c]) => + val a = refs(modifyRef.ref.monitorId).asInstanceOf[a] + val (newA, b) = modifyRef.f(a) + val acc2 = acc + drawIndent(indent, junc + " Modify ref ") + drawRef(modifyRef.ref, newA) + ", produced " + b.toString + "\n" + LoopState(modifyRef.body(b), acc2, indent, init, count, refs.updated(modifyRef.ref.monitorId, newA), deferreds).asLeft[String].pure[M] + + case Left(mkDeferred: MkDeferred[a, b]) => + val deferred = new Deferred[a](mkDeferred.id) + val acc2 = acc + drawIndent(indent, junc + " Create deferred ") + drawDeferred(deferred) + "\n" + LoopState(mkDeferred.body(deferred), acc2, indent, init, count, refs, deferreds).asLeft[String].pure[M] + + case Left(tryGetDeferred: TryGetDeferred[a, b]) => + val optA = deferreds.get(tryGetDeferred.deferred.monitorId).map(_.asInstanceOf[a]) + val acc2 = acc + drawIndent(indent, junc + " Try get deferred ") + drawDeferred(tryGetDeferred.deferred) + " = " + optA.toString + "\n" + LoopState(tryGetDeferred.body(optA), acc2, indent, init, count, refs, deferreds).asLeft[String].pure[M] + + case Left(completeDeferred: CompleteDeferred[a, b]) => + val newA = deferreds.get(completeDeferred.deferred.monitorId).map(_.asInstanceOf[a]).getOrElse(completeDeferred.a) + val acc2 = acc + drawIndent(indent, junc + " Complete deferred ") + drawDeferred(completeDeferred.deferred) + " with value " + newA.toString + "\n" + LoopState(completeDeferred.body(), acc2, indent, init, count, refs, deferreds.updated(completeDeferred.deferred.monitorId, newA)).asLeft[String].pure[M] case Right(a) => (acc + drawIndent(indent, TurnRight + " Pure " + a.show + trailing)).asRight[LoopState].pure[M] @@ -231,6 +291,6 @@ object ThreadT { } } - loop(target, "", Nil, true) + loop(target, "", Nil, true, 0, Map.empty, Map.empty) } } diff --git a/core/shared/src/test/scala/coop/DeferredSpecs.scala b/core/shared/src/test/scala/coop/DeferredSpecs.scala new file mode 100644 index 0000000..82792a7 --- /dev/null +++ b/core/shared/src/test/scala/coop/DeferredSpecs.scala @@ -0,0 +1,96 @@ +/* + * Copyright 2020 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package coop + +import cats.data.State +import cats.kernel.Monoid +import cats.syntax.all._ + +import org.specs2.mutable.Specification + +class DeferredSpecs extends Specification { + "Deferred" should { + "complete" in { + val eff = for { + d <- Deferred[State[Int, *], Int] + dp = d[State[Int, *]] + _ <- dp.complete(10) + v <- dp.get + _ <- ThreadT.liftF(State.set(v)) + } yield () + + runToCompletionEmpty(eff) mustEqual 10 + } + + "complete only once" in { + val eff = for { + d <- Deferred[State[Int, *], Int] + dp = d[State[Int, *]] + _ <- dp.complete(10) + _ <- dp.complete(20) + v <- dp.get + _ <- ThreadT.liftF(State.set(v)) + } yield () + + runToCompletionEmpty(eff) mustEqual 10 + } + + "tryGet returns None for unset Deferred" in { + val eff = for { + d <- Deferred[State[Option[Int], *], Int] + dp = d[State[Option[Int], *]] + v <- dp.tryGet + _ <- ThreadT.liftF(State.set(v)) + } yield () + + runToCompletionEmpty(eff) mustEqual None + } + + "tryGet returns Some for set Deferred" in { + val eff = for { + d <- Deferred[State[Option[Int], *], Int] + dp = d[State[Option[Int], *]] + _ <- dp.complete(10) + v <- dp.tryGet + _ <- ThreadT.liftF(State.set(v)) + } yield () + + runToCompletionEmpty(eff) mustEqual Some(10) + } + + "get blocks until set" in { + val eff = for { + state <- Ref.of[State[Int, *], Int](0) + modifyGate <- Deferred[State[Int, *], Unit] + readGate <- Deferred[State[Int, *], Unit] + _ <- ThreadT.start(modifyGate.get[State[Int, *]] *> state.updateAndGet[State[Int, *]](_ * 2) *> readGate.complete[State[Int, *]](())) + _ <- ThreadT.start(state.set[State[Int, *]](1) *> modifyGate.complete[State[Int, *]](())) + _ <- readGate.get[State[Int, *]] + v <- state.get[State[Int, *]] + _ <- ThreadT.liftF(State.set(v)) + } yield () + + runToCompletionEmpty(eff) mustEqual 2 + } + } + + def runToCompletionEmpty[S: Monoid](fa: ThreadT[State[S, *], _]): S = + runToCompletion(Monoid[S].empty, fa) + + def runToCompletion[S](init: S, fa: ThreadT[State[S, *], _]): S = + ThreadT.roundRobin(fa).runS(init).value +} diff --git a/core/shared/src/test/scala/coop/RefSpecs.scala b/core/shared/src/test/scala/coop/RefSpecs.scala new file mode 100644 index 0000000..2b94570 --- /dev/null +++ b/core/shared/src/test/scala/coop/RefSpecs.scala @@ -0,0 +1,81 @@ +/* + * Copyright 2020 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package coop + +import cats.data.State +import cats.kernel.Monoid + +import org.specs2.mutable.Specification + +class RefSpecs extends Specification { + "Ref" should { + "get and set successfully" in { + val eff = for { + ref <- Ref.of[State[(Int, Int), *], Int](5) + refp = ref[State[(Int, Int), *]] + v1 <- refp.getAndSet(10) + v2 <- refp.get + _ <- ThreadT.liftF(State.set((v1, v2))) + } yield () + + runToCompletionEmpty(eff) mustEqual ((5, 10)) + } + + "get and update successfully" in { + val eff = for { + ref <- Ref.of[State[(Int, Int), *], Int](5) + refp = ref[State[(Int, Int), *]] + v1 <- refp.getAndUpdate(_ * 2) + v2 <- refp.get + _ <- ThreadT.liftF(State.set((v1, v2))) + } yield () + + runToCompletionEmpty(eff) mustEqual ((5, 10)) + } + + "update and get successfully" in { + val eff = for { + ref <- Ref.of[State[(Int, Int), *], Int](5) + refp = ref[State[(Int, Int), *]] + v1 <- refp.updateAndGet(_ * 2) + v2 <- refp.get + _ <- ThreadT.liftF(State.set((v1, v2))) + } yield () + + runToCompletionEmpty(eff) mustEqual ((10, 10)) + } + + "set from a background thread" in { + val eff = for { + ref <- Ref.of[State[Int, *], Int](5) + refp = ref[State[Int, *]] + _ <- ThreadT.start(refp.set(10)) + _ <- ThreadT.cede[State[Int, *]] + v <- refp.get + _ <- ThreadT.liftF(State.set(v)) + } yield () + + runToCompletionEmpty(eff) mustEqual(10) + } + } + + def runToCompletionEmpty[S: Monoid](fa: ThreadT[State[S, *], _]): S = + runToCompletion(Monoid[S].empty, fa) + + def runToCompletion[S](init: S, fa: ThreadT[State[S, *], _]): S = + ThreadT.roundRobin(fa).runS(init).value +}