Skip to content

Commit

Permalink
Merge pull request #4116 from armanbilge/fix/mapref-access-semantics
Browse files Browse the repository at this point in the history
Relax `access` semantics in `MapRef` implementations
  • Loading branch information
armanbilge authored Aug 20, 2024
2 parents 447e140 + 895d210 commit f6690f8
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 51 deletions.
31 changes: 10 additions & 21 deletions std/shared/src/main/scala/cats/effect/std/MapRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import cats.effect.kernel._
import cats.syntax.all._

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean

/**
* This is a total map from `K` to `Ref[F, V]`.
Expand Down Expand Up @@ -138,28 +137,24 @@ object MapRef extends MapRefCompanionPlatform {

def access: F[(Option[V], Option[V] => F[Boolean])] =
delay {
val hasBeenCalled = new AtomicBoolean(false)
val init = chm.get(k)
if (init == null) {
val set: Option[V] => F[Boolean] = { (opt: Option[V]) =>
opt match {
case None =>
delay(hasBeenCalled.compareAndSet(false, true) && !chm.containsKey(k))
case Some(newV) =>
delay {
// it was initially empty
hasBeenCalled.compareAndSet(false, true) && chm.putIfAbsent(k, newV) == null
}
delay(!chm.containsKey(k))
case Some(newV) => // it was initially empty
delay(chm.putIfAbsent(k, newV) == null)
}
}
(None, set)
} else {
val set: Option[V] => F[Boolean] = { (opt: Option[V]) =>
opt match {
case None =>
delay(hasBeenCalled.compareAndSet(false, true) && chm.remove(k, init))
delay(chm.remove(k, init))
case Some(newV) =>
delay(hasBeenCalled.compareAndSet(false, true) && chm.replace(k, init, newV))
delay(chm.replace(k, init, newV))
}
}
(Some(init), set)
Expand Down Expand Up @@ -305,31 +300,25 @@ object MapRef extends MapRefCompanionPlatform {
class HandleRef(k: K) extends Ref[F, Option[V]] {
def access: F[(Option[V], Option[V] => F[Boolean])] =
sync.delay {
val hasBeenCalled = new AtomicBoolean(false)
val init = map.get(k)
init match {
case None =>
val set: Option[V] => F[Boolean] = { (opt: Option[V]) =>
opt match {
case None =>
sync.delay(hasBeenCalled.compareAndSet(false, true) && !map.contains(k))
case Some(newV) =>
sync.delay {
// it was initially empty
hasBeenCalled
.compareAndSet(false, true) && map.putIfAbsent(k, newV).isEmpty
}
sync.delay(!map.contains(k))
case Some(newV) => // it was initially empty
sync.delay(map.putIfAbsent(k, newV).isEmpty)
}
}
(None, set)
case Some(old) =>
val set: Option[V] => F[Boolean] = { (opt: Option[V]) =>
opt match {
case None =>
sync.delay(hasBeenCalled.compareAndSet(false, true) && map.remove(k, old))
sync.delay(map.remove(k, old))
case Some(newV) =>
sync.delay(
hasBeenCalled.compareAndSet(false, true) && map.replace(k, old, newV))
sync.delay(map.replace(k, old, newV))
}
}
(init, set)
Expand Down
15 changes: 0 additions & 15 deletions tests/jvm/src/test/scala/cats/effect/std/MapRefJVMSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,6 @@ class MapRefJVMSpec extends BaseSpec {
op.map(a => a must_=== true)
}

"access - setter should fail if called twice" in real {
val op = for {
r <- MapRef.ofScalaConcurrentTrieMap[IO, Unit, Int]
_ <- r(()).set(Some(0))
accessed <- r(()).access
(value, setter) = accessed
cond1 <- setter(value.map(_ + 1))
_ <- r(()).set(value)
cond2 <- setter(None)
result <- r(()).get
} yield cond1 && !cond2 && result == Some(0)

op.map(a => a must_=== true)
}

"tryUpdate - modification occurs successfully" in real {
val op = for {
r <- MapRef.ofScalaConcurrentTrieMap[IO, Unit, Int]
Expand Down
15 changes: 0 additions & 15 deletions tests/shared/src/test/scala/cats/effect/std/MapRefSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,6 @@ class MapRefSpec extends BaseSpec {
op.map(a => a must_=== true)
}

"access - setter should fail if called twice" in real {
val op = for {
r <- MapRef.ofConcurrentHashMap[IO, Unit, Int]()
_ <- r(()).set(Some(0))
accessed <- r(()).access
(value, setter) = accessed
cond1 <- setter(value.map(_ + 1))
_ <- r(()).set(value)
cond2 <- setter(None)
result <- r(()).get
} yield cond1 && !cond2 && result == Some(0)

op.map(a => a must_=== true)
}

"tryUpdate - modification occurs successfully" in real {
val op = for {
r <- MapRef.ofConcurrentHashMap[IO, Unit, Int]()
Expand Down

0 comments on commit f6690f8

Please sign in to comment.