Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TestableRandom and ThreadLocalSecureRandom #56

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions lila/src/main/scala/Random.scala
Original file line number Diff line number Diff line change
@@ -1,29 +1,44 @@
package scalalib

object ThreadLocalRandom extends RandomApi:
protected[scalalib] def impl = java.util.concurrent.ThreadLocalRandom.current()
protected def impl = java.util.concurrent.ThreadLocalRandom.current()

/** Thread-safe, but the underlying impl, [[java.security.SecureRandom]], uses synchronized methods, which has
* degraded performance under high contention. See [[ThreadLocalSecureRandom]].
*/
object SecureRandom extends RandomApi:
protected val impl = java.security.SecureRandom.getInstanceStrong()

private abstract class RandomApi:
/** An alternative to [[SecureRandom]] which offers improved performance under high contention.
*/
object ThreadLocalSecureRandom extends RandomApi:
private val store = new java.lang.ThreadLocal[java.util.Random]:
override def initialValue = java.security.SecureRandom.getInstanceStrong()
protected def impl = store.get

/** A deterministic random number generator for testing purposes.
*/
final class TestableRandom(seed: Long) extends RandomApi:
protected val impl = new java.util.Random(seed)

sealed abstract class RandomApi:
protected def impl: java.util.Random

def nextBoolean() = impl.nextBoolean
def nextDouble() = impl.nextDouble
def nextFloat() = impl.nextFloat
def nextGaussian() = impl.nextGaussian
def nextInt() = impl.nextInt
def nextInt(n: Int) = impl.nextInt(n)
def nextLong() = impl.nextLong
def nextLong(l: Long) = impl.nextLong(l)
final def nextBoolean() = impl.nextBoolean
final def nextDouble() = impl.nextDouble
final def nextFloat() = impl.nextFloat
final def nextGaussian() = impl.nextGaussian
final def nextInt() = impl.nextInt
final def nextInt(n: Int) = impl.nextInt(n)
final def nextLong() = impl.nextLong
final def nextLong(l: Long) = impl.nextLong(l)

def nextBytes(len: Int): Array[Byte] =
final def nextBytes(len: Int): Array[Byte] =
val bytes = new Array[Byte](len)
impl.nextBytes(bytes)
bytes

def nextString(len: Int): String =
final def nextString(len: Int): String =
val randomImpl = impl
val chars = RandomApi.chars
val arr = new Array[Char](len)
Expand All @@ -34,17 +49,17 @@ private abstract class RandomApi:
i += 1
String.valueOf(arr)

def shuffle[T, C](xs: IterableOnce[T])(using scala.collection.BuildFrom[xs.type, T, C]): C =
final def shuffle[T, C](xs: IterableOnce[T])(using scala.collection.BuildFrom[xs.type, T, C]): C =
scala.util.Random(impl).shuffle(xs)

def oneOf[A](seq: scala.collection.IndexedSeq[A]): Option[A] =
final def oneOf[A](seq: scala.collection.IndexedSeq[A]): Option[A] =
val len = seq.length
if len > 0 then Some(seq(impl.nextInt(len))) else None

// odds(1) = 100% true
// odds(2) = 50% true
// odds(3) = 33% true
def odds(n: Int): Boolean = impl.nextFloat() * n < 1
final def odds(n: Int): Boolean = impl.nextFloat() * n < 1f

private object RandomApi:
// private vals are accessed directly as a static field.
Expand Down