diff --git a/src/main/scala/au/csiro/variantspark/input/CsvStdFeatureSource.scala b/src/main/scala/au/csiro/variantspark/input/CsvStdFeatureSource.scala index e6994911..d6d99344 100644 --- a/src/main/scala/au/csiro/variantspark/input/CsvStdFeatureSource.scala +++ b/src/main/scala/au/csiro/variantspark/input/CsvStdFeatureSource.scala @@ -14,6 +14,7 @@ import au.csiro.variantspark.data.Feature import au.csiro.variantspark.data.FeatureBuilder import au.csiro.variantspark.data.DataBuilder import au.csiro.variantspark.data.StdFeature +import au.csiro.variantspark.data.DefRepresentationFactory import org.apache.spark.broadcast.Broadcast class MapAccumulator @@ -80,11 +81,13 @@ class MapAccumulator * ingestion of traditional CSV files for analysis. */ case class CsvStdFeatureSource[V](data: RDD[String], - defaultType: VariableType = ContinuousVariable, csvFormat: CSVFormat = DefaultCSVFormatSpec) + defaultType: VariableType = ContinuousVariable, + optVariableTypes: Option[RDD[String]] = None, csvFormat: CSVFormat = DefaultCSVFormatSpec) extends FeatureSource { val variableNames: List[String] = new CSVParser(csvFormat).parseLine(data.first).get.tail val br_variableNames: Broadcast[List[String]] = data.context.broadcast(variableNames) + val br_types = data.context.broadcast(optVariableTypes.map(parseTypes)) lazy val transposedData: RDD[(String, Array[String])] = { // expects data in coma separated format of @@ -148,12 +151,25 @@ case class CsvStdFeatureSource[V](data: RDD[String], .drop(1) } + def parseTypes(typeRDD: RDD[String]): Map[String, VariableType] = { + typeRDD + .mapPartitions { it => + val csvParser = new CSVParser(csvFormat) + it.map(csvParser.parseLine(_).get).map(l => (l.head, VariableType.fromString(l.last))) + } + .collect() + .toMap + } + def features: RDD[Feature] = featuresAs[Vector] def featuresAs[T](implicit cr: DataBuilder[T]): RDD[Feature] = { - transposedData.map({ + val types = br_types.value + val representationFactory = DefRepresentationFactory + transposedData.map { case (varId, values) => - StdFeature.from[T](varId, defaultType, values.toList) - }) + val variableType = types.flatMap(_.get(varId)).getOrElse(defaultType) + StdFeature.from[T](varId, variableType, values.toList) + } } }