diff --git a/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala b/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala index 623aeef6bc..7c4c86d421 100644 --- a/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala +++ b/core/shared/src/main/scala/cats/effect/concurrent/Ref.scala @@ -240,7 +240,7 @@ object Ref { * Ref.lens[IO, Foo, String](refA)(_.bar, (foo: Foo) => (bar: String) => foo.copy(bar = bar)) * }}} * */ - def lens[F[_], A, B <: AnyRef](ref: Ref[F, A])(get: A => B, set: A => B => A)(implicit F: Functor[F]): Ref[F, B] = + def lens[F[_], A, B <: AnyRef](ref: Ref[F, A])(get: A => B, set: A => B => A)(implicit F: Sync[F]): Ref[F, B] = new LensRef[F, A, B](ref)(get, set) final class ApplyBuilders[F[_]](val F: Sync[F]) extends AnyVal { @@ -323,7 +323,7 @@ object Ref { final private[concurrent] class LensRef[F[_], A, B <: AnyRef](underlying: Ref[F, A])( lensGet: A => B, lensSet: A => B => A - )(implicit F: Functor[F]) + )(implicit F: Sync[F]) extends Ref[F, B] { override def get: F[B] = F.map(underlying.get)(a => lensGet(a)) @@ -363,19 +363,24 @@ object Ref { modify(a => f(a).value) } - override def access: F[(B, B => F[Boolean])] = - F.map(underlying.get) { snapshotA => + override val access: F[(B, B => F[Boolean])] = + F.flatMap(underlying.get) { snapshotA => val snapshotB = lensGet(snapshotA) - val hasBeenCalled = new AtomicBoolean(false) - val setter = (b: B) => { - F.map(underlying.tryModify { a => - if (hasBeenCalled.compareAndSet(false, true) && (lensGet(a) eq snapshotB)) - (lensSet(a)(b), true) - else - (a, false) - })(_.getOrElse(false)) + val setter = F.delay { + val hasBeenCalled = new AtomicBoolean(false) + + (b: B) => { + F.flatMap(F.delay(hasBeenCalled.compareAndSet(false, true))) { hasBeenCalled => + F.map(underlying.tryModify { a => + if (hasBeenCalled && (lensGet(a) eq snapshotB)) + (lensSet(a)(b), true) + else + (a, false) + })(_.getOrElse(false)) + } + } } - (snapshotB, setter) + setter.tupleLeft(snapshotB) } private def lensModify(s: A)(f: B => B): A = lensSet(s)(f(lensGet(s))) diff --git a/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala b/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala index 9cef928a66..f456992e9f 100644 --- a/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala +++ b/core/shared/src/test/scala/cats/effect/concurrent/LensRefTests.scala @@ -215,6 +215,23 @@ class LensRefTests extends AsyncFunSuite with Matchers { }.void) } + test("access - successfully modifies underlying Ref after A is modified without affecting B") { + val op = for { + refA <- Ref[IO].of(Foo(0, -1)) + refB = Ref.lens[IO, Foo, Integer](refA)(Foo.get, Foo.set) + valueAndSetter <- refB.access + (value, setter) = valueAndSetter + _ <- refA.update(_.copy(baz = -2)) + success <- setter(value + 1) + a <- refA.get + } yield (success, a) + run(op.map { + case (success, a) => + success shouldBe true + a shouldBe Foo(1, -2) + }.void) + } + test("access - setter fails to modify underlying Ref if value is modified before setter is called") { val op = for { refA <- Ref[IO].of(Foo(0, -1))