diff --git a/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala b/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala index 381cd3cae2..7c4c86d421 100644 --- a/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala +++ b/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala @@ -226,6 +226,23 @@ object Ref { */ def unsafe[F[_], A](a: A)(implicit F: Sync[F]): Ref[F, A] = new SyncRef[F, A](new AtomicReference[A](a)) + /** + * Creates an instance focused on a component of another Ref's value. + * Delegates every get and modification to underlying Ref, so both instances are always in sync. + * + * Example: + * + * {{{ + * case class Foo(bar: String, baz: Int) + * + * val refA: Ref[IO, Foo] = ??? + * val refB: Ref[IO, String] = + * Ref.lens[IO, Foo, String](refA)(_.bar, (foo: Foo) => (bar: String) => foo.copy(bar = bar)) + * }}} + * */ + def lens[F[_], A, B <: AnyRef](ref: Ref[F, A])(get: A => B, set: A => B => A)(implicit F: Sync[F]): Ref[F, B] = + new LensRef[F, A, B](ref)(get, set) + final class ApplyBuilders[F[_]](val F: Sync[F]) extends AnyVal { /** @@ -303,6 +320,72 @@ object Ref { trans(F.compose[(A, *)].compose[A => *].map(underlying.access)(trans(_))) } + final private[concurrent] class LensRef[F[_], A, B <: AnyRef](underlying: Ref[F, A])( + lensGet: A => B, + lensSet: A => B => A + )(implicit F: Sync[F]) + extends Ref[F, B] { + override def get: F[B] = F.map(underlying.get)(a => lensGet(a)) + + override def set(b: B): F[Unit] = underlying.update(a => lensModify(a)(_ => b)) + + override def getAndSet(b: B): F[B] = underlying.modify { a => + (lensModify(a)(_ => b), lensGet(a)) + } + + override def update(f: B => B): F[Unit] = + underlying.update(a => lensModify(a)(f)) + + override def modify[C](f: B => (B, C)): F[C] = + underlying.modify { a => + val oldB = lensGet(a) + val (b, c) = f(oldB) + (lensSet(a)(b), c) + } + + override def tryUpdate(f: B => B): F[Boolean] = + F.map(tryModify(a => (f(a), ())))(_.isDefined) + + override def tryModify[C](f: B => (B, C)): F[Option[C]] = + underlying.tryModify { a => + val oldB = lensGet(a) + val (b, result) = f(oldB) + (lensSet(a)(b), result) + } + + override def tryModifyState[C](state: State[B, C]): F[Option[C]] = { + val f = state.runF.value + tryModify(a => f(a).value) + } + + override def modifyState[C](state: State[B, C]): F[C] = { + val f = state.runF.value + modify(a => f(a).value) + } + + override val access: F[(B, B => F[Boolean])] = + F.flatMap(underlying.get) { snapshotA => + val snapshotB = lensGet(snapshotA) + val setter = F.delay { + val hasBeenCalled = new AtomicBoolean(false) + + (b: B) => { + F.flatMap(F.delay(hasBeenCalled.compareAndSet(false, true))) { hasBeenCalled => + F.map(underlying.tryModify { a => + if (hasBeenCalled && (lensGet(a) eq snapshotB)) + (lensSet(a)(b), true) + else + (a, false) + })(_.getOrElse(false)) + } + } + } + setter.tupleLeft(snapshotB) + } + + private def lensModify(s: A)(f: B => B): A = lensSet(s)(f(lensGet(s))) + } + implicit def catsInvariantForRef[F[_]: Functor]: Invariant[Ref[F, *]] = new Invariant[Ref[F, *]] { override def imap[A, B](fa: Ref[F, A])(f: A => B)(g: B => A): Ref[F, B] = diff --git a/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala b/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala new file mode 100644 index 0000000000..f456992e9f --- /dev/null +++ b/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2017-2019 The Typelevel Cats-effect Project Developers + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect.concurrent + +import cats.data.State +import cats.effect.IO +import org.scalatest.Succeeded +import org.scalatest.compatible.Assertion +import org.scalatest.funsuite.AsyncFunSuite +import org.scalatest.matchers.should.Matchers + +import scala.concurrent.Future + +class LensRefTests extends AsyncFunSuite with Matchers { + + private def run(t: IO[Unit]): Future[Assertion] = t.as(Succeeded).unsafeToFuture() + + case class Foo(bar: Integer, baz: Integer) + + object Foo { + def get(foo: Foo): Integer = foo.bar + + def set(foo: Foo)(bar: Integer): Foo = foo.copy(bar = bar) + } + + test("get - returns B") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.get + } yield result + + run(op.map(_ shouldEqual 0)) + } + + test("set - modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + _ <- refB.set(1) + result <- refA.get + } yield result + + run(op.map(_ shouldEqual Foo(1, -1))) + } + + test("getAndSet - modifies underlying Ref and returns previous value") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + oldValue <- refB.getAndSet(1) + a <- refA.get + } yield (oldValue, a) + + run(op.map { + case (oldValue, a) => + oldValue shouldBe 0 + a shouldEqual Foo(1, -1) + }) + } + + test("update - modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + _ <- refB.update(_ + 1) + a <- refA.get + } yield a + + run(op.map(_ shouldBe Foo(1, -1))) + } + + test("modify - modifies underlying Ref and returns a value") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.modify(bar => (bar + 1, 10)) + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe 10 + a shouldEqual Foo(1, -1) + }) + } + + test("tryUpdate - successfully modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.tryUpdate(_ + 1) + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe true + a shouldBe Foo(1, -1) + }) + } + + test("tryUpdate - fails to modify original value if it's already been modified concurrently") { + val updateRefUnsafely: Ref[IO, Integer] => Unit = _.set(5).unsafeRunSync() + + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.tryUpdate { currentValue => + updateRefUnsafely(refB) + currentValue + 1 + } + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe false + a shouldBe Foo(5, -1) + }) + } + + test("tryModify - successfully modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.tryModify(bar => (bar + 1, "A")) + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe Some("A") + a shouldBe Foo(1, -1) + }) + } + + test("tryModify - fails to modify original value if it's already been modified concurrently") { + val updateRefUnsafely: Ref[IO, Integer] => Unit = _.set(5).unsafeRunSync() + + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.tryModify { currentValue => + updateRefUnsafely(refB) + (currentValue + 1, 10) + } + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe None + a shouldBe Foo(5, -1) + }) + } + + test("tryModifyState - successfully modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.tryModifyState(State.apply(x => (x + 1, "A"))) + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe Some("A") + a shouldBe Foo(1, -1) + }) + } + + test("modifyState - successfully modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + result <- refB.modifyState(State.apply(x => (x + 1, "A"))) + a <- refA.get + } yield (result, a) + + run(op.map { + case (result, a) => + result shouldBe "A" + a shouldBe Foo(1, -1) + }) + } + + test("access - successfully modifies underlying Ref") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + valueAndSetter <- refB.access + (value, setter) = valueAndSetter + success <- setter(value + 1) + a <- refA.get + } yield (success, a) + run(op.map { + case (success, a) => + success shouldBe true + a shouldBe Foo(1, -1) + }.void) + } + + test("access - successfully modifies underlying Ref after A is modified without affecting B") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + valueAndSetter <- refB.access + (value, setter) = valueAndSetter + _ <- refA.update(_.copy(baz = -2)) + success <- setter(value + 1) + a <- refA.get + } yield (success, a) + run(op.map { + case (success, a) => + success shouldBe true + a shouldBe Foo(1, -2) + }.void) + } + + test("access - setter fails to modify underlying Ref if value is modified before setter is called") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + valueAndSetter <- refB.access + (value, setter) = valueAndSetter + _ <- refA.set(Foo(5, -1)) + success <- setter(value + 1) + a <- refA.get + } yield (success, a) + + run(op.map { + case (success, result) => + success shouldBe false + result shouldBe Foo(5, -1) + }.void) + } + + test("access - setter fails the second time") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + valueAndSetter <- refB.access + (_, setter) = valueAndSetter + result1 <- setter(1) + result2 <- setter(2) + a <- refA.get + } yield (result1, result2, a) + + run(op.map { + case (result1, result2, a) => + result1 shouldBe true + result2 shouldBe false + a shouldBe Foo(1, -1) + }.void) + } + +}