Skip to content

Commit

Permalink
Merge pull request #827 from jwojnowski/ref-lens
Browse files Browse the repository at this point in the history
Add Ref.lens constructor
  • Loading branch information
djspiewak authored May 2, 2020
2 parents c2fb504 + d02a15a commit 912bb20
Show file tree
Hide file tree
Showing 2 changed files with 355 additions and 0 deletions.
83 changes: 83 additions & 0 deletions core/shared/src/main/scala/cats/effect/concurrent/Ref.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,23 @@ object Ref {
*/
def unsafe[F[_], A](a: A)(implicit F: Sync[F]): Ref[F, A] = new SyncRef[F, A](new AtomicReference[A](a))

/**
* Creates an instance focused on a component of another Ref's value.
* Delegates every get and modification to underlying Ref, so both instances are always in sync.
*
* Example:
*
* {{{
* case class Foo(bar: String, baz: Int)
*
* val refA: Ref[IO, Foo] = ???
* val refB: Ref[IO, String] =
* Ref.lens[IO, Foo, String](refA)(_.bar, (foo: Foo) => (bar: String) => foo.copy(bar = bar))
* }}}
* */
def lens[F[_], A, B <: AnyRef](ref: Ref[F, A])(get: A => B, set: A => B => A)(implicit F: Sync[F]): Ref[F, B] =
new LensRef[F, A, B](ref)(get, set)

final class ApplyBuilders[F[_]](val F: Sync[F]) extends AnyVal {

/**
Expand Down Expand Up @@ -303,6 +320,72 @@ object Ref {
trans(F.compose[(A, *)].compose[A => *].map(underlying.access)(trans(_)))
}

final private[concurrent] class LensRef[F[_], A, B <: AnyRef](underlying: Ref[F, A])(
lensGet: A => B,
lensSet: A => B => A
)(implicit F: Sync[F])
extends Ref[F, B] {
override def get: F[B] = F.map(underlying.get)(a => lensGet(a))

override def set(b: B): F[Unit] = underlying.update(a => lensModify(a)(_ => b))

override def getAndSet(b: B): F[B] = underlying.modify { a =>
(lensModify(a)(_ => b), lensGet(a))
}

override def update(f: B => B): F[Unit] =
underlying.update(a => lensModify(a)(f))

override def modify[C](f: B => (B, C)): F[C] =
underlying.modify { a =>
val oldB = lensGet(a)
val (b, c) = f(oldB)
(lensSet(a)(b), c)
}

override def tryUpdate(f: B => B): F[Boolean] =
F.map(tryModify(a => (f(a), ())))(_.isDefined)

override def tryModify[C](f: B => (B, C)): F[Option[C]] =
underlying.tryModify { a =>
val oldB = lensGet(a)
val (b, result) = f(oldB)
(lensSet(a)(b), result)
}

override def tryModifyState[C](state: State[B, C]): F[Option[C]] = {
val f = state.runF.value
tryModify(a => f(a).value)
}

override def modifyState[C](state: State[B, C]): F[C] = {
val f = state.runF.value
modify(a => f(a).value)
}

override val access: F[(B, B => F[Boolean])] =
F.flatMap(underlying.get) { snapshotA =>
val snapshotB = lensGet(snapshotA)
val setter = F.delay {
val hasBeenCalled = new AtomicBoolean(false)

(b: B) => {
F.flatMap(F.delay(hasBeenCalled.compareAndSet(false, true))) { hasBeenCalled =>
F.map(underlying.tryModify { a =>
if (hasBeenCalled && (lensGet(a) eq snapshotB))
(lensSet(a)(b), true)
else
(a, false)
})(_.getOrElse(false))
}
}
}
setter.tupleLeft(snapshotB)
}

private def lensModify(s: A)(f: B => B): A = lensSet(s)(f(lensGet(s)))
}

implicit def catsInvariantForRef[F[_]: Functor]: Invariant[Ref[F, *]] =
new Invariant[Ref[F, *]] {
override def imap[A, B](fa: Ref[F, A])(f: A => B)(g: B => A): Ref[F, B] =
Expand Down
272 changes: 272 additions & 0 deletions core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
/*
* Copyright (c) 2017-2019 The Typelevel Cats-effect Project Developers
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.concurrent

import cats.data.State
import cats.effect.IO
import org.scalatest.Succeeded
import org.scalatest.compatible.Assertion
import org.scalatest.funsuite.AsyncFunSuite
import org.scalatest.matchers.should.Matchers

import scala.concurrent.Future

class LensRefTests extends AsyncFunSuite with Matchers {

private def run(t: IO[Unit]): Future[Assertion] = t.as(Succeeded).unsafeToFuture()

case class Foo(bar: Integer, baz: Integer)

object Foo {
def get(foo: Foo): Integer = foo.bar

def set(foo: Foo)(bar: Integer): Foo = foo.copy(bar = bar)
}

test("get - returns B") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.get
} yield result

run(op.map(_ shouldEqual 0))
}

test("set - modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
_ <- refB.set(1)
result <- refA.get
} yield result

run(op.map(_ shouldEqual Foo(1, -1)))
}

test("getAndSet - modifies underlying Ref and returns previous value") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
oldValue <- refB.getAndSet(1)
a <- refA.get
} yield (oldValue, a)

run(op.map {
case (oldValue, a) =>
oldValue shouldBe 0
a shouldEqual Foo(1, -1)
})
}

test("update - modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
_ <- refB.update(_ + 1)
a <- refA.get
} yield a

run(op.map(_ shouldBe Foo(1, -1)))
}

test("modify - modifies underlying Ref and returns a value") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.modify(bar => (bar + 1, 10))
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe 10
a shouldEqual Foo(1, -1)
})
}

test("tryUpdate - successfully modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.tryUpdate(_ + 1)
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe true
a shouldBe Foo(1, -1)
})
}

test("tryUpdate - fails to modify original value if it's already been modified concurrently") {
val updateRefUnsafely: Ref[IO, Integer] => Unit = _.set(5).unsafeRunSync()

val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.tryUpdate { currentValue =>
updateRefUnsafely(refB)
currentValue + 1
}
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe false
a shouldBe Foo(5, -1)
})
}

test("tryModify - successfully modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.tryModify(bar => (bar + 1, "A"))
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe Some("A")
a shouldBe Foo(1, -1)
})
}

test("tryModify - fails to modify original value if it's already been modified concurrently") {
val updateRefUnsafely: Ref[IO, Integer] => Unit = _.set(5).unsafeRunSync()

val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.tryModify { currentValue =>
updateRefUnsafely(refB)
(currentValue + 1, 10)
}
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe None
a shouldBe Foo(5, -1)
})
}

test("tryModifyState - successfully modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.tryModifyState(State.apply(x => (x + 1, "A")))
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe Some("A")
a shouldBe Foo(1, -1)
})
}

test("modifyState - successfully modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
result <- refB.modifyState(State.apply(x => (x + 1, "A")))
a <- refA.get
} yield (result, a)

run(op.map {
case (result, a) =>
result shouldBe "A"
a shouldBe Foo(1, -1)
})
}

test("access - successfully modifies underlying Ref") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
valueAndSetter <- refB.access
(value, setter) = valueAndSetter
success <- setter(value + 1)
a <- refA.get
} yield (success, a)
run(op.map {
case (success, a) =>
success shouldBe true
a shouldBe Foo(1, -1)
}.void)
}

test("access - successfully modifies underlying Ref after A is modified without affecting B") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
valueAndSetter <- refB.access
(value, setter) = valueAndSetter
_ <- refA.update(_.copy(baz = -2))
success <- setter(value + 1)
a <- refA.get
} yield (success, a)
run(op.map {
case (success, a) =>
success shouldBe true
a shouldBe Foo(1, -2)
}.void)
}

test("access - setter fails to modify underlying Ref if value is modified before setter is called") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
valueAndSetter <- refB.access
(value, setter) = valueAndSetter
_ <- refA.set(Foo(5, -1))
success <- setter(value + 1)
a <- refA.get
} yield (success, a)

run(op.map {
case (success, result) =>
success shouldBe false
result shouldBe Foo(5, -1)
}.void)
}

test("access - setter fails the second time") {
val op = for {
refA <- Ref[IO].of(Foo(0, -1))
refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set)
valueAndSetter <- refB.access
(_, setter) = valueAndSetter
result1 <- setter(1)
result2 <- setter(2)
a <- refA.get
} yield (result1, result2, a)

run(op.map {
case (result1, result2, a) =>
result1 shouldBe true
result2 shouldBe false
a shouldBe Foo(1, -1)
}.void)
}

}

0 comments on commit 912bb20

Please sign in to comment.