diff --git a/python/varspark/test/test_core.py b/python/varspark/test/test_core.py
index 01f516d1..c9b9da33 100644
--- a/python/varspark/test/test_core.py
+++ b/python/varspark/test/test_core.py
@@ -5,7 +5,7 @@
 from pyspark import SparkConf
 from pyspark.sql import SparkSession
 
-from varspark import VariantsContext
+from varspark import VariantsContext, RFModelContext
 from varspark.test import find_variants_jar, PROJECT_DIR
 
 THIS_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -28,7 +28,8 @@ def tearDownClass(self):
 
 
 class VariantSparkAPITestCase(VariantSparkPySparkTestCase):
-
+    # Attributes set on self are only visible to the other tests when initialised in setUp.
+    # TODO: consider creating the model, importance and FDR objects here to support multiple tests.
     def setUp(self):
         self.spark = SparkSession(self.sc)
         self.vc = VariantsContext(self.spark)
@@ -39,22 +40,26 @@ def test_variants_context_parameter_type(self):
         self.assertEqual('keyword argument label_file_path = 123 doesn\'t match signature str',
                          str(cm.exception))
 
-    def test_importance_analysis_from_vcf(self):
+    def test_rfmodel(self):
+        """Fit an RFModelContext on the chr22 sample data; check top variable, OOB error and FDR."""
         label_data_path = os.path.join(PROJECT_DIR, 'data/chr22-labels.csv')
         label = self.vc.load_label(label_file_path=label_data_path, col_name='22_16050678')
         feature_data_path = os.path.join(PROJECT_DIR, 'data/chr22_1000.vcf')
         features = self.vc.import_vcf(vcf_file_path=feature_data_path)
-
-        imp_analysis = features.importance_analysis(label, 200, None, True, 17, 50, 3)
+        rf = RFModelContext(self.spark, mtry_fraction=None, oob=True, seed=17, var_ordinal_levels=3)
+        rf.fit_trees(features, label, n_trees=200, batch_size=50)
+        imp_analysis = rf.importance_analysis()
         imp_vars = imp_analysis.important_variables(20)
-        most_imp_var = imp_vars[0][0]
+        most_imp_var = imp_vars['variable'][0]
         self.assertEqual('22_16050678_C_T', most_imp_var)
-        df = imp_analysis.variable_importance()
+        df = imp_analysis.variable_importance(normalized=True)
         self.assertEqual('22_16050678_C_T',
-                         str(df.orderBy('importance', ascending=False).collect()[0][0]))
-        oob_error = imp_analysis.oob_error()
-        self.assertAlmostEqual(0.004578754578754579, oob_error, 4)
-
+                         str(df.sort_values(by='importance', ascending=False)['variant_id'].iloc[0]))
+        oob_error = rf.oob_error()
+        self.assertAlmostEqual(0.004578754578754579, oob_error, 4)
+        fdr_calc = rf.get_lfdr()
+        _, fdr = fdr_calc.compute_fdr(countThreshold=2, local_fdr_cutoff=0.05)
+        self.assertAlmostEqual(0.0002976892628282768, fdr, places=7)
 
 if __name__ == '__main__':
     unittest.main(verbosity=2)