NOTE(review): line breaks below were restored from the hunk line counts
(-6,7/+6,8 and -21,18/+22,20). Indentation widths and the position of the blank
context line in each hunk are reconstructed by convention, not recovered from the
original -- confirm against the repository before applying.

diff --git a/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala b/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala
index d789c099..6f4e82d1 100644
--- a/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala
+++ b/src/test/scala/au/csiro/variantspark/misc/ReproducibilityTest.scala
@@ -6,7 +6,8 @@ import org.junit.Test
 import org.junit.Ignore
 import org.junit.Assert._
 import au.csiro.variantspark.api._
-import org.apache.spark.sql.SparkSession
+import au.csiro.variantspark.algo.RandomForestParams
+import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.SparkConf
 
 /**
@@ -21,18 +22,20 @@ class ReproducibilityTest extends SparkTest {
     .getOrCreate()
 
   @Test
-  @Ignore
   def testReproducibleResults() {
     implicit val vsContext = VSContext(spark)
+    implicit val sqlContext = spark.sqlContext
     val features = vsContext.importVCF("data/chr22_1000.vcf", 3)
     val label = vsContext.loadLabel("data/chr22-labels.csv", "22_16051249")
-    val impAnalysis1 = features.importanceAnalysis(label, nTrees = 40, seed = Some(13L),
-      mtryFraction = None, batchSize = 20)
-    val topVariables1 = impAnalysis1.importantVariables(20)
+    val params = RandomForestParams(seed = 13L)
+    val rfModel1 = RFModelTrainer.trainModel(features, label, params, 40, 20)
+    val impAnalysis1 = new ImportanceAnalysis(sqlContext, features, rfModel1)
+    val topVariables1 = impAnalysis1.importantVariables(20, false)
     topVariables1.foreach(println)
-    val impAnalysis2 = features.importanceAnalysis(label, nTrees = 40, seed = Some(13L),
-      mtryFraction = None, batchSize = 20)
-    val topVariables2 = impAnalysis2.importantVariables(20)
+    println()
+    val rfModel2 = RFModelTrainer.trainModel(features, label, params, 40, 20)
+    val impAnalysis2 = new ImportanceAnalysis(sqlContext, features, rfModel2)
+    val topVariables2 = impAnalysis2.importantVariables(20, false)
     topVariables2.foreach(println)
     topVariables1.zip(topVariables2).foreach { p => assertEquals(p._1, p._2) }
   }