Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove all sources of impurity #31

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions core/shared/src/main/scala/coop/Deferred.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2020 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package coop

import cats.{Applicative, Monad}
import cats.free.FreeT
import cats.syntax.all._

import ThreadF._

final class Deferred[A] private[coop] (private[coop] val monitorId: MonitorId) { self =>
def tryGet[F[_]: Applicative]: ThreadT[F, Option[A]] =
FreeT.liftF(TryGetDeferred(this, identity[Option[A]]))

def get[F[_]: Monad]: ThreadT[F, A] =
tryGet[F].flatMap {
case Some(a) => Applicative[ThreadT[F, *]].pure(a)
case None => ThreadT.await(monitorId) >> get[F]
}

def complete[F[_]: Monad](a: A): ThreadT[F, Unit] =
FreeT.liftF(CompleteDeferred(this, a, () => ()): ThreadF[Unit]) >> ThreadT.notify[F](monitorId)

def apply[F[_]: Monad]: DeferredPartiallyApplied[F] =
new DeferredPartiallyApplied[F]

class DeferredPartiallyApplied[F[_]: Monad] {
def tryGet: ThreadT[F, Option[A]] = self.tryGet
def get: ThreadT[F, A] = self.get
def complete(a: A): ThreadT[F, Unit] = self.complete(a)
}
}

object Deferred {
def apply[F[_]: Applicative, A]: ThreadT[F, Deferred[A]] =
ThreadT.monitor[F].flatMap(id => FreeT.liftF(MkDeferred(id, identity[Deferred[A]])))
}
62 changes: 62 additions & 0 deletions core/shared/src/main/scala/coop/Ref.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2020 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package coop

import cats.Applicative
import cats.free.FreeT

import ThreadF._

final class Ref[A] private[coop] (private[coop] val monitorId: MonitorId) { self =>
def get[F[_]: Applicative]: ThreadT[F, A] =
modify(a => (a, a))

def set[F[_]: Applicative](a: A): ThreadT[F, Unit] =
modify(_ => (a, ()))

def modify[F[_]: Applicative, B](f: A => (A, B)): ThreadT[F, B] =
FreeT.liftF(ModifyRef(this, f, identity[B]))

def getAndSet[F[_]: Applicative](a: A): ThreadT[F, A] =
modify(oldA => (a, oldA))

def getAndUpdate[F[_]: Applicative](f: A => A): ThreadT[F, A] =
modify(a => (f(a), a))

def updateAndGet[F[_]: Applicative](f: A => A): ThreadT[F, A] =
modify { a =>
val newA = f(a)
(newA, newA)
}

def apply[F[_]: Applicative]: RefPartiallyApplied[F] =
new RefPartiallyApplied[F]

class RefPartiallyApplied[F[_]: Applicative] {
val get: ThreadT[F, A] = self.get
def set(a: A): ThreadT[F, Unit] = self.set(a)
def modify[B](f: A => (A, B)): ThreadT[F, B] = self.modify(f)
def getAndSet(a: A): ThreadT[F, A] = self.getAndSet(a)
def getAndUpdate(f: A => A): ThreadT[F, A] = self.getAndUpdate(f)
def updateAndGet(f: A => A): ThreadT[F, A] = self.updateAndGet(f)
}
}

object Ref {
def of[F[_]: Applicative, A](a: A): ThreadT[F, Ref[A]] =
ThreadT.monitor[F].flatMap(id => FreeT.liftF(MkRef(a, id, identity[Ref[A]])))
}
17 changes: 16 additions & 1 deletion core/shared/src/main/scala/coop/ThreadF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ object ThreadF {
case Annotate(text, results) => Annotate(text, () => f(results()))
case Indent(results) => Indent(() => f(results()))
case Dedent(results) => Dedent(() => f(results()))

case MkRef(a, id, body) => MkRef(a, id, body.andThen(f))
case ModifyRef(ref, modF, body) => ModifyRef(ref, modF, body.andThen(f))

case MkDeferred(id, body) => MkDeferred(id, body.andThen(f))
case TryGetDeferred(deferred, body) => TryGetDeferred(deferred, body.andThen(f))
case CompleteDeferred(deferred, a, body) => CompleteDeferred(deferred, a, () => f(body()))
}
}

Expand All @@ -50,6 +57,14 @@ object ThreadF {
final case class Indent[A](results: () => A) extends ThreadF[A]
final case class Dedent[A](results: () => A) extends ThreadF[A]

final case class MkRef[A, B](a: A, id: MonitorId, body: Ref[A] => B) extends ThreadF[B]
final case class ModifyRef[A, B, C](ref: Ref[A], f: A => (A, B), body: B => C) extends ThreadF[C]

final case class MkDeferred[A, B](id: MonitorId, body: Deferred[A] => B) extends ThreadF[B]
final case class TryGetDeferred[A, B](deferred: Deferred[A], body: Option[A] => B) extends ThreadF[B]
final case class CompleteDeferred[A, B](deferred: Deferred[A], a: A, body: () => B) extends ThreadF[B]

// an opaque fresh id
final class MonitorId private[coop] ()
final case class MonitorId private[coop] (private[coop] val id: Int)
private object MonitorId
}
116 changes: 88 additions & 28 deletions core/shared/src/main/scala/coop/ThreadT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,44 +63,70 @@ object ThreadT {
case class LoopState(
head: Option[() => ThreadT[M, _]],
work: Queue[() => ThreadT[M, _]],
locks: Map[MonitorId, Queue[() => ThreadT[M, _]]])
monitorCount: Int,
locks: Map[MonitorId, Queue[() => ThreadT[M, _]]],
refs: Map[MonitorId, _],
deferreds: Map[MonitorId, _]
)

Monad[M].tailRecM(LoopState(Some(() => main), Queue.empty, Map.empty)) { ls =>
val LoopState(head, work, locks) = ls
Monad[M].tailRecM(LoopState(Some(() => main), Queue.empty, 0, Map.empty, Map.empty, Map.empty)) { ls =>
val LoopState(head, work, count, locks, refs, deferreds) = ls

head.tupleRight(work).orElse(work.dequeueOption) match {
case Some((head, tail)) =>
head().resume map {
case Left(Fork(left, right)) =>
Left(LoopState(Some(left), tail.enqueue(right), locks))
Left(LoopState(Some(left), tail.enqueue(right), count, locks, refs, deferreds))

case Left(Cede(results)) =>
val tail2 = tail.enqueue(results)
Left(LoopState(None, tail2, locks))
Left(LoopState(None, tail2, count, locks, refs, deferreds))

case Left(Done) | Right(_) =>
Left(LoopState(None, tail, locks))
Left(LoopState(None, tail, count, locks, refs, deferreds))

case Left(Monitor(f)) =>
val id = new MonitorId()
Left(LoopState(Some(() => f(id)), tail, locks + (id -> Queue.empty)))
val id = new MonitorId(count)
Left(LoopState(Some(() => f(id)), tail, count + 1, locks + (id -> Queue.empty), refs, deferreds))

case Left(Await(id, results)) =>
Left(LoopState(None, tail, locks.updated(id, locks(id).enqueue(results))))
Left(LoopState(None, tail, count, locks.updated(id, locks(id).enqueue(results)), refs, deferreds))

case Left(Notify(id, results)) =>
// enqueueAll was added in 2.13
val tail2 = locks(id).foldLeft(tail)(_.enqueue(_))
Left(LoopState(None, tail2.enqueue(results), locks.updated(id, Queue.empty)))
Left(LoopState(None, tail2.enqueue(results), count, locks.updated(id, Queue.empty), refs, deferreds))

case Left(Annotate(_, results)) =>
Left(LoopState(Some(results), tail, locks))
Left(LoopState(Some(results), tail, count, locks, refs, deferreds))

case Left(Indent(results)) =>
Left(LoopState(Some(results), tail, locks))
Left(LoopState(Some(results), tail, count, locks, refs, deferreds))

case Left(Dedent(results)) =>
Left(LoopState(Some(results), tail, locks))
Left(LoopState(Some(results), tail, count, locks, refs, deferreds))

case Left(mkref: MkRef[a, b]) =>
val a = mkref.a
val ref = new Ref[a](mkref.id)
Left(LoopState(Some(() => mkref.body(ref)), tail, count, locks, refs + (mkref.id -> a), deferreds))

case Left(modifyRef: ModifyRef[a, b, c]) =>
val a = refs(modifyRef.ref.monitorId).asInstanceOf[a]
val (newA, b) = modifyRef.f(a)
Left(LoopState(Some(() => modifyRef.body(b)), tail, count, locks, refs.updated(modifyRef.ref.monitorId, newA), deferreds))

case Left(mkDeferred: MkDeferred[a, b]) =>
val deferred = new Deferred[a](mkDeferred.id)
Left(LoopState(Some(() => mkDeferred.body(deferred)), tail, count, locks, refs, deferreds))

case Left(tryGetDeferred: TryGetDeferred[a, b]) =>
val optA = deferreds.get(tryGetDeferred.deferred.monitorId).map(_.asInstanceOf[a])
Left(LoopState(Some(() => tryGetDeferred.body(optA)), tail, count, locks, refs, deferreds))

case Left(completeDeferred: CompleteDeferred[a, b]) =>
val newA = deferreds.get(completeDeferred.deferred.monitorId).map(_.asInstanceOf[a]).getOrElse(completeDeferred.a)
Left(LoopState(Some(() => completeDeferred.body()), tail, count, locks, refs, deferreds.updated(completeDeferred.deferred.monitorId, newA)))
}

// if we have outstanding awaits but no active fibers, then we're deadlocked
Expand Down Expand Up @@ -157,15 +183,23 @@ object ThreadT {

def drawId(id: MonitorId): String = "0x" + id.hashCode.toHexString.toUpperCase

def drawRef(ref: Ref[_], a: Any): String = "Ref(id = " + drawId(ref.monitorId) + ") =" + a.toString

def drawDeferred(deferred: Deferred[_]): String = "Deferred(id = " + drawId(deferred.monitorId) + ")"

case class LoopState(
target: ThreadT[M, A],
acc: String,
indent: List[Boolean],
init: Boolean = false)
init: Boolean = false,
monitorCount: Int,
refs: Map[MonitorId, _],
deferreds: Map[MonitorId, _]
)

def loop(target: ThreadT[M, A], acc: String, indent: List[Boolean], init: Boolean): M[String] = {
Monad[M].tailRecM(LoopState(target, acc, indent, init)) { ls =>
val LoopState(target, acc0, indent, init) = ls
def loop(target: ThreadT[M, A], acc: String, indent: List[Boolean], init: Boolean, monitorCount: Int, refs: Map[MonitorId, _], deferreds: Map[MonitorId, _]): M[String] = {
Monad[M].tailRecM(LoopState(target, acc, indent, init, monitorCount, refs, deferreds)) { ls =>
val LoopState(target, acc0, indent, init, count, refs, deferreds) = ls

val junc0 = if (init) InverseTurnRight else Junction
val trailing = if (indent != Nil) "\n" + drawIndent(indent, "") else ""
Expand All @@ -187,42 +221,68 @@ object ThreadT {
case Left(Fork(left, right)) =>
val leading = drawIndent(indent, junc + " Fork") + "\n" + drawIndent(indent, ForkStr)

loop(right(), "", true :: indent, false) map { rightStr =>
loop(right(), "", true :: indent, false, count, refs, deferreds) map { rightStr =>
val acc2 = acc + leading + "\n" + rightStr + "\n"
LoopState(left(), acc2, indent, false).asLeft[String]
LoopState(left(), acc2, indent, false, count, refs, deferreds).asLeft[String]
}

case Left(Cede(results)) =>
val acc2 = acc + drawIndent(indent, junc + " Cede") + "\n"
LoopState(results(), acc2, indent, false).asLeft[String].pure[M]
LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M]

case Left(Done) =>
(acc + drawIndent(indent, TurnRight + " Done" + trailing)).asRight[LoopState].pure[M]

case Left(Monitor(f)) =>
val id = new MonitorId
LoopState(f(id), acc, indent, init).asLeft[String].pure[M] // don't render the creation
val id = new MonitorId(count)
LoopState(f(id), acc, indent, init, count + 1, refs, deferreds).asLeft[String].pure[M] // don't render the creation

case Left(Await(id, results)) =>
val acc2 = acc + drawIndent(indent, junc + " Await ") + drawId(id) + "\n"
LoopState(results(), acc2, indent, false).asLeft[String].pure[M]
LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M]

case Left(Notify(id, results)) =>
val acc2 = acc + drawIndent(indent, junc + " Notify ") + drawId(id) + "\n"
LoopState(results(), acc2, indent, false).asLeft[String].pure[M]
LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M]

case Left(Annotate(name, results)) =>
val acc2 = acc + drawIndent(indent, junc + s" $name") + "\n"
LoopState(results(), acc2, indent, false).asLeft[String].pure[M]
LoopState(results(), acc2, indent, false, count, refs, deferreds).asLeft[String].pure[M]

case Left(Indent(results)) =>
val acc2 = acc + drawIndent(indent, IndentStr) + "\n"
LoopState(results(), acc2, false :: indent, false).asLeft[String].pure[M]
LoopState(results(), acc2, false :: indent, false, count, refs, deferreds).asLeft[String].pure[M]

case Left(Dedent(results)) =>
val indent2 = indent.tail
val acc2 = acc + drawIndent(indent2, DedentStr) + "\n"
LoopState(results(), acc2, indent2, false).asLeft[String].pure[M]
LoopState(results(), acc2, indent2, false, count, refs, deferreds).asLeft[String].pure[M]

case Left(mkRef: MkRef[a, b]) =>
val ref = new Ref[a](mkRef.id)
val acc2 = acc + drawIndent(indent, junc + " Create ref ") + drawRef(ref, mkRef.a) + "\n"
LoopState(mkRef.body(ref), acc2, indent, init, count, refs + (ref.monitorId -> mkRef.a), deferreds).asLeft[String].pure[M]

case Left(modifyRef: ModifyRef[a, b, c]) =>
val a = refs(modifyRef.ref.monitorId).asInstanceOf[a]
val (newA, b) = modifyRef.f(a)
val acc2 = acc + drawIndent(indent, junc + " Modify ref ") + drawRef(modifyRef.ref, newA) + ", produced " + b.toString + "\n"
LoopState(modifyRef.body(b), acc2, indent, init, count, refs.updated(modifyRef.ref.monitorId, newA), deferreds).asLeft[String].pure[M]

case Left(mkDeferred: MkDeferred[a, b]) =>
val deferred = new Deferred[a](mkDeferred.id)
val acc2 = acc + drawIndent(indent, junc + " Create deferred ") + drawDeferred(deferred) + "\n"
LoopState(mkDeferred.body(deferred), acc2, indent, init, count, refs, deferreds).asLeft[String].pure[M]

case Left(tryGetDeferred: TryGetDeferred[a, b]) =>
val optA = deferreds.get(tryGetDeferred.deferred.monitorId).map(_.asInstanceOf[a])
val acc2 = acc + drawIndent(indent, junc + " Try get deferred ") + drawDeferred(tryGetDeferred.deferred) + " = " + optA.toString + "\n"
LoopState(tryGetDeferred.body(optA), acc2, indent, init, count, refs, deferreds).asLeft[String].pure[M]

case Left(completeDeferred: CompleteDeferred[a, b]) =>
val newA = deferreds.get(completeDeferred.deferred.monitorId).map(_.asInstanceOf[a]).getOrElse(completeDeferred.a)
val acc2 = acc + drawIndent(indent, junc + " Complete deferred ") + drawDeferred(completeDeferred.deferred) + " with value " + newA.toString + "\n"
LoopState(completeDeferred.body(), acc2, indent, init, count, refs, deferreds.updated(completeDeferred.deferred.monitorId, newA)).asLeft[String].pure[M]

case Right(a) =>
(acc + drawIndent(indent, TurnRight + " Pure " + a.show + trailing)).asRight[LoopState].pure[M]
Expand All @@ -231,6 +291,6 @@ object ThreadT {
}
}

loop(target, "", Nil, true)
loop(target, "", Nil, true, 0, Map.empty, Map.empty)
}
}
Loading