Skip to content

Commit

Permalink
Merge pull request #56 from isaacl/addThreadLocalSecure
Browse files Browse the repository at this point in the history
Add TestableRandom and ThreadLocalSecureRandom
  • Loading branch information
ornicar authored Sep 16, 2024
2 parents 7e1d5e1 + df783fc commit 4da5a79
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions lila/src/main/scala/Random.scala
Original file line number Diff line number Diff line change
@@ -1,29 +1,44 @@
package scalalib

object ThreadLocalRandom extends RandomApi:
protected[scalalib] def impl = java.util.concurrent.ThreadLocalRandom.current()
protected def impl = java.util.concurrent.ThreadLocalRandom.current()

/** Thread-safe, but the underlying impl, [[java.security.SecureRandom]], uses synchronized methods, which has
* degraded performance under high contention. See [[ThreadLocalSecureRandom]].
*/
object SecureRandom extends RandomApi:
protected val impl = java.security.SecureRandom.getInstanceStrong()

private abstract class RandomApi:
/** An alternative to [[SecureRandom]] which offers improved performance under high contention.
*/
object ThreadLocalSecureRandom extends RandomApi:
private val store = new java.lang.ThreadLocal[java.util.Random]:
override def initialValue = java.security.SecureRandom.getInstanceStrong()
protected def impl = store.get

/** A deterministic random number generator for testing purposes.
*/
final class TestableRandom(seed: Long) extends RandomApi:
protected val impl = new java.util.Random(seed)

sealed abstract class RandomApi:
protected def impl: java.util.Random

def nextBoolean() = impl.nextBoolean
def nextDouble() = impl.nextDouble
def nextFloat() = impl.nextFloat
def nextGaussian() = impl.nextGaussian
def nextInt() = impl.nextInt
def nextInt(n: Int) = impl.nextInt(n)
def nextLong() = impl.nextLong
def nextLong(l: Long) = impl.nextLong(l)
final def nextBoolean() = impl.nextBoolean
final def nextDouble() = impl.nextDouble
final def nextFloat() = impl.nextFloat
final def nextGaussian() = impl.nextGaussian
final def nextInt() = impl.nextInt
final def nextInt(n: Int) = impl.nextInt(n)
final def nextLong() = impl.nextLong
final def nextLong(l: Long) = impl.nextLong(l)

def nextBytes(len: Int): Array[Byte] =
final def nextBytes(len: Int): Array[Byte] =
val bytes = new Array[Byte](len)
impl.nextBytes(bytes)
bytes

def nextString(len: Int): String =
final def nextString(len: Int): String =
val randomImpl = impl
val chars = RandomApi.chars
val arr = new Array[Char](len)
Expand All @@ -34,17 +49,17 @@ private abstract class RandomApi:
i += 1
String.valueOf(arr)

def shuffle[T, C](xs: IterableOnce[T])(using scala.collection.BuildFrom[xs.type, T, C]): C =
final def shuffle[T, C](xs: IterableOnce[T])(using scala.collection.BuildFrom[xs.type, T, C]): C =
scala.util.Random(impl).shuffle(xs)

def oneOf[A](seq: scala.collection.IndexedSeq[A]): Option[A] =
final def oneOf[A](seq: scala.collection.IndexedSeq[A]): Option[A] =
val len = seq.length
if len > 0 then Some(seq(impl.nextInt(len))) else None

// odds(1) = 100% true
// odds(2) = 50% true
// odds(3) = 33% true
def odds(n: Int): Boolean = impl.nextFloat() * n < 1
final def odds(n: Int): Boolean = impl.nextFloat() * n < 1f

private object RandomApi:
// private vals are accessed directly as a static field.
Expand Down

0 comments on commit 4da5a79

Please sign in to comment.