Merge pull request #72 from data-catering/4-generate-data-for-all-possible-combinations

4 generate data for all possible combinations
pflooky authored Sep 13, 2024
2 parents d249a35 + 192cead commit faec9e2
Showing 9 changed files with 186 additions and 77 deletions.
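In short: this merge wires an `allCombinations` step option through the stack. It adds an `ALL_COMBINATIONS` constant, an `allCombinations(enable: Boolean)` method on `ConnectionTaskBuilder`, a `generateCombinationRecords` path in `DataGeneratorFactory` that cross-joins the allowed values of every 'oneOf' field so every possible combination is generated, a covering test in `PlanBuilderTest`, and a refactor of the SQL-rewriting helpers in `DataGenerator` into a generic `replaceByRegex`.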
2 changes: 2 additions & 0 deletions .github/workflows/check.yml
@@ -17,6 +17,8 @@ jobs:
       - name: Run integration tests
         id: tests
         uses: data-catering/insta-integration@v1
+        env:
+          LOG_LEVEL: debug
       - name: Print results
         run: |
           echo "Records generated: ${{ steps.tests.outputs.num_records_generated }}"

@@ -1,6 +1,6 @@
 package io.github.datacatering.datacaterer.api.connection
 
-import io.github.datacatering.datacaterer.api.model.Constants.{ENABLE_DATA_VALIDATION, FORMAT}
+import io.github.datacatering.datacaterer.api.model.Constants.{ALL_COMBINATIONS, ENABLE_DATA_VALIDATION, FORMAT}
 import io.github.datacatering.datacaterer.api.{ConnectionConfigWithTaskBuilder, CountBuilder, FieldBuilder, GeneratorBuilder, MetadataSourceBuilder, SchemaBuilder, StepBuilder, TaskBuilder, TasksBuilder, ValidationBuilder, WaitConditionBuilder}
 import io.github.datacatering.datacaterer.api.model.{Step, Task}
@@ -56,6 +56,11 @@ trait ConnectionTaskBuilder[T] {
     this
   }
 
+  def allCombinations(enable: Boolean): ConnectionTaskBuilder[T] = {
+    this.step = Some(getStep.option(ALL_COMBINATIONS, enable.toString))
+    this
+  }
+
   def numPartitions(numPartitions: Int): ConnectionTaskBuilder[T] = {
     this.step = Some(getStep.numPartitions(numPartitions))
     this
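As a usage sketch (hedged: the connection name and format mirror the new test in `PlanBuilderTest` further down; the comments are my reading of the diff):

```scala
import io.github.datacatering.datacaterer.api.ConnectionConfigWithTaskBuilder

// Opt a step into exhaustive generation: every combination of its
// 'oneOf' field values is produced instead of randomly sampled records.
val jsonTask = ConnectionConfigWithTaskBuilder()
  .file("my_json", "json")
  .allCombinations(true)
```

Under the hood, `allCombinations(true)` just records the step option `ALL_COMBINATIONS -> "true"`, which `DataGeneratorFactory` checks when choosing the generation path.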

@@ -112,6 +112,9 @@ object Constants {
   lazy val HTTP_PARAMETER_TYPE = "httpParamType"
   lazy val POST_SQL_EXPRESSION = "postSqlExpression"
 
+  //step options
+  lazy val ALL_COMBINATIONS = "allCombinations"
+
   //field labels
   lazy val LABEL_NAME = "name"
   lazy val LABEL_USERNAME = "username"

@@ -209,6 +212,7 @@ object Constants {
"spark.sql.cbo.planStats.enabled" -> "true",
"spark.sql.legacy.allowUntypedScalaUDF" -> "true",
"spark.sql.legacy.allowParameterlessCount" -> "true",
"spark.sql.legacy.allowParameterlessCount" -> "true",
"spark.sql.statistics.histogram.enabled" -> "true",
"spark.sql.shuffle.partitions" -> "10",
"spark.sql.catalog.postgres" -> "",

@@ -1,6 +1,6 @@
 package io.github.datacatering.datacaterer.api
 
-import io.github.datacatering.datacaterer.api.model.Constants.FOREIGN_KEY_DELIMITER
+import io.github.datacatering.datacaterer.api.model.Constants.{ALL_COMBINATIONS, FOREIGN_KEY_DELIMITER}
 import io.github.datacatering.datacaterer.api.connection.FileBuilder
 import io.github.datacatering.datacaterer.api.model.{DataCatererConfiguration, ExpressionValidation, ForeignKeyRelation, PauseWaitCondition}
 import org.junit.runner.RunWith
@@ -215,4 +215,14 @@ class PlanBuilderTest extends AnyFunSuite {
     assert(fk.head._2.isEmpty)
     assert(fk.head._3.size == 1)
   }
+
+  test("Can create a step that will generate records for all combinations") {
+    val jsonTask = ConnectionConfigWithTaskBuilder().file("my_json", "json")
+      .allCombinations(true)
+
+    assert(jsonTask.step.isDefined)
+    assert(jsonTask.step.get.step.options.nonEmpty)
+    assert(jsonTask.step.get.step.options.contains(ALL_COMBINATIONS))
+    assert(jsonTask.step.get.step.options(ALL_COMBINATIONS).equalsIgnoreCase("true"))
+  }
 }

@@ -1,14 +1,16 @@
 package io.github.datacatering.datacaterer.core.generator
 
-import io.github.datacatering.datacaterer.api.model.Constants.SQL_GENERATOR
+import io.github.datacatering.datacaterer.api.model.Constants.{ALL_COMBINATIONS, ONE_OF_GENERATOR, SQL_GENERATOR}
 import io.github.datacatering.datacaterer.api.model.{Field, PerColumnCount, Step}
 import io.github.datacatering.datacaterer.core.exception.InvalidStepCountGeneratorConfigurationException
 import io.github.datacatering.datacaterer.core.generator.provider.DataGenerator
+import io.github.datacatering.datacaterer.core.generator.provider.OneOfDataGenerator.RandomOneOfDataGenerator
 import io.github.datacatering.datacaterer.core.model.Constants._
 import io.github.datacatering.datacaterer.core.util.GeneratorUtil.{applySqlExpressions, getDataGenerator}
 import io.github.datacatering.datacaterer.core.util.ObjectMapperUtil
 import io.github.datacatering.datacaterer.core.util.PlanImplicits.FieldOps
 import net.datafaker.Faker
+import org.apache.log4j.Logger
 import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -21,27 +23,33 @@ case class Holder(__index_inc: Long)

 class DataGeneratorFactory(faker: Faker)(implicit val sparkSession: SparkSession) {
 
+  private val LOGGER = Logger.getLogger(getClass.getName)
   private val OBJECT_MAPPER = ObjectMapperUtil.jsonObjectMapper
   registerSparkFunctions()
 
   def generateDataForStep(step: Step, dataSourceName: String, startIndex: Long, endIndex: Long): DataFrame = {
+    //need to have separate code for generating all possible combinations
     val structFieldsWithDataGenerators = step.schema.fields.map(getStructWithGenerators).getOrElse(List())
     val indexedDf = sparkSession.createDataFrame(Seq.range(startIndex, endIndex).map(Holder))
     generateDataViaSql(structFieldsWithDataGenerators, step, indexedDf)
       .alias(s"$dataSourceName.${step.name}")
   }
 
   private def generateDataViaSql(dataGenerators: List[DataGenerator[_]], step: Step, indexedDf: DataFrame): DataFrame = {
-    val structType = StructType(dataGenerators.map(_.structField))
-    val genSqlExpression = dataGenerators.map(dg => s"${dg.generateSqlExpressionWrapper} AS `${dg.structField.name}`")
-    val df = indexedDf.selectExpr(genSqlExpression: _*)
+    val allRecordsDf = if (step.options.contains(ALL_COMBINATIONS) && step.options(ALL_COMBINATIONS).equalsIgnoreCase("true")) {
+      generateCombinationRecords(dataGenerators, indexedDf)
+    } else {
+      val genSqlExpression = dataGenerators.map(dg => s"${dg.generateSqlExpressionWrapper} AS `${dg.structField.name}`")
+      val df = indexedDf.selectExpr(genSqlExpression: _*)
 
-    val perColDf = step.count.perColumn
-      .map(perCol => generateRecordsPerColumn(dataGenerators, step, perCol, df))
-      .getOrElse(df)
-    if (!perColDf.storageLevel.useMemory) perColDf.cache()
+      step.count.perColumn
+        .map(perCol => generateRecordsPerColumn(dataGenerators, step, perCol, df))
+        .getOrElse(df)
+    }
 
-    val dfWithMetadata = attachMetadata(perColDf, structType)
+    if (!allRecordsDf.storageLevel.useMemory) allRecordsDf.cache()
+    val structType = StructType(dataGenerators.map(_.structField))
+    val dfWithMetadata = attachMetadata(allRecordsDf, structType)
     val dfAllFields = attachMetadata(applySqlExpressions(dfWithMetadata), structType)
     if (!dfAllFields.storageLevel.useMemory) dfAllFields.cache()
     dfAllFields

@@ -100,6 +108,31 @@ class DataGeneratorFactory(faker: Faker)(implicit val sparkSession: SparkSession
     explodeCount.select(PER_COLUMN_INDEX_COL + ".*", perColumnCount.columnNames: _*)
   }
 
+  private def generateCombinationRecords(dataGenerators: List[DataGenerator[_]], indexedDf: DataFrame) = {
+    LOGGER.debug("Attempting to generate all combinations of 'oneOf' fields")
+    //TODO could be nested oneOf fields
+    val oneOfFields = dataGenerators
+      .filter(x => x.isInstanceOf[RandomOneOfDataGenerator] || x.options.contains(ONE_OF_GENERATOR))
+    val nonOneOfFields = dataGenerators.filter(x => !x.isInstanceOf[RandomOneOfDataGenerator] && !x.options.contains(ONE_OF_GENERATOR))
+
+    val oneOfFieldsSql = oneOfFields.map(field => {
+      val fieldValues = field.structField.metadata.getStringArray(ONE_OF_GENERATOR)
+      sparkSession.createDataFrame(Seq(1L).map(Holder))
+        .selectExpr(explode(typedlit(fieldValues)).as(field.structField.name).expr.sql)
+    })
+    val nonOneOfFieldsSql = nonOneOfFields.map(dg => s"${dg.generateSqlExpressionWrapper} AS `${dg.structField.name}`")
+
+    if (oneOfFields.nonEmpty) {
+      LOGGER.debug("Found fields defined with 'oneOf', attempting to create all combinations of possible values")
+      val pairwiseCombinations = oneOfFieldsSql.reduce((a, b) => a.crossJoin(b))
+      val selectExpr = pairwiseCombinations.columns.toList ++ nonOneOfFieldsSql
+      pairwiseCombinations.selectExpr(selectExpr: _*)
+    } else {
+      LOGGER.debug("No fields defined with 'oneOf', unable to create all possible combinations")
+      indexedDf
+    }
+  }
+
   private def generateDataWithSchema(dataGenerators: List[DataGenerator[_]]): UserDefinedFunction = {
     udf((sqlGen: Int) => {
       (1L to sqlGen)

@@ -132,58 +165,59 @@ class DataGeneratorFactory(faker: Faker)(implicit val sparkSession: SparkSession
   }
 
   private def defineRandomLengthView(): Unit = {
-    sparkSession.sql(s"""WITH lengths AS (
-      | SELECT sequence(1, $DATA_CATERER_RANDOM_LENGTH_MAX_VALUE) AS length_list
-      |),
-      |
-      |-- Explode the sequence into individual length values
-      |exploded_lengths AS (
-      | SELECT explode(length_list) AS length
-      | FROM lengths
-      |),
-      |
-      |-- Create the heuristic cumulative distribution dynamically
-      |length_distribution AS (
-      | SELECT
-      | length,
-      | CASE
-      | WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
-      | WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
-      | ELSE 0.1 * POWER(2, length - 11)
-      | END AS weight,
-      | SUM(CASE
-      | WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
-      | WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
-      | ELSE 0.1 * POWER(2, length - 11)
-      | END) OVER (ORDER BY length) AS cumulative_weight,
-      | SUM(CASE
-      | WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
-      | WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
-      | ELSE 0.1 * POWER(2, length - 11)
-      | END) OVER () AS total_weight
-      | FROM exploded_lengths
-      |),
-      |
-      |-- Calculate cumulative probabilities
-      |length_probabilities AS (
-      | SELECT
-      | length,
-      | cumulative_weight / total_weight AS cumulative_prob
-      | FROM length_distribution
-      |),
-      |
-      |-- Select a single random length based on the heuristic distribution
-      |random_length AS (
-      | SELECT
-      | length
-      | FROM length_probabilities
-      | WHERE cumulative_prob >= rand()
-      | ORDER BY cumulative_prob
-      | LIMIT 1
-      |)
-      |
-      |-- Final query to get the single random length
-      |SELECT * FROM random_length;""".stripMargin)
+    sparkSession.sql(
+      s"""WITH lengths AS (
+         | SELECT sequence(1, $DATA_CATERER_RANDOM_LENGTH_MAX_VALUE) AS length_list
+         |),
+         |
+         |-- Explode the sequence into individual length values
+         |exploded_lengths AS (
+         | SELECT explode(length_list) AS length
+         | FROM lengths
+         |),
+         |
+         |-- Create the heuristic cumulative distribution dynamically
+         |length_distribution AS (
+         | SELECT
+         | length,
+         | CASE
+         | WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
+         | WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
+         | ELSE 0.1 * POWER(2, length - 11)
+         | END AS weight,
+         | SUM(CASE
+         | WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
+         | WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
+         | ELSE 0.1 * POWER(2, length - 11)
+         | END) OVER (ORDER BY length) AS cumulative_weight,
+         | SUM(CASE
+         | WHEN length <= 5 THEN 0.001 * POWER(2, length - 1)
+         | WHEN length <= 10 THEN 0.01 * POWER(2, length - 6)
+         | ELSE 0.1 * POWER(2, length - 11)
+         | END) OVER () AS total_weight
+         | FROM exploded_lengths
+         |),
+         |
+         |-- Calculate cumulative probabilities
+         |length_probabilities AS (
+         | SELECT
+         | length,
+         | cumulative_weight / total_weight AS cumulative_prob
+         | FROM length_distribution
+         |),
+         |
+         |-- Select a single random length based on the heuristic distribution
+         |random_length AS (
+         | SELECT
+         | length
+         | FROM length_probabilities
+         | WHERE cumulative_prob >= rand()
+         | ORDER BY cumulative_prob
+         | LIMIT 1
+         |)
+         |
+         |-- Final query to get the single random length
+         |SELECT * FROM random_length;""".stripMargin)
       .createOrReplaceTempView(DATA_CATERER_RANDOM_LENGTH)
   }
 }
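The heart of the change is `generateCombinationRecords`: every 'oneOf' field is exploded into a single-column DataFrame of its allowed values, the per-field DataFrames are folded together with `crossJoin`, and the non-'oneOf' fields are then generated once per combination. A minimal, self-contained sketch of that crossJoin technique (field names and values here are illustrative, not from the commit):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, typedlit}

object CombinationsSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("combinations-sketch").getOrCreate()
  import spark.implicits._

  // Two hypothetical 'oneOf' fields with two and three allowed values
  val statusDf = Seq(1).toDF("dummy").select(explode(typedlit(Seq("open", "closed"))).as("status"))
  val tierDf = Seq(1).toDF("dummy").select(explode(typedlit(Seq("bronze", "silver", "gold"))).as("tier"))

  // Folding with crossJoin yields the Cartesian product: 2 x 3 = 6 rows,
  // one per combination, mirroring oneOfFieldsSql.reduce((a, b) => a.crossJoin(b))
  List(statusDf, tierDf).reduce((a, b) => a.crossJoin(b)).show()
}
```

Worth noting: the row count multiplies across fields, so exhaustive combinations grow quickly, and the TODO in the diff flags that nested 'oneOf' fields are not yet handled.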

@@ -2,10 +2,12 @@ package io.github.datacatering.datacaterer.core.generator.provider

 import io.github.datacatering.datacaterer.api.model.Constants.{ARRAY_MAXIMUM_LENGTH, ARRAY_MINIMUM_LENGTH, ENABLED_EDGE_CASE, ENABLED_NULL, IS_UNIQUE, PROBABILITY_OF_EDGE_CASE, PROBABILITY_OF_NULL, RANDOM_SEED, STATIC}
 import io.github.datacatering.datacaterer.api.model.generator.BaseGenerator
+import io.github.datacatering.datacaterer.core.model.Constants.DATA_CATERER_RANDOM_LENGTH
 import net.datafaker.Faker
 import org.apache.spark.sql.functions.{expr, rand, when}
 import org.apache.spark.sql.types.StructField
 
+import java.util.regex.Pattern
 import scala.annotation.tailrec
 import scala.collection.mutable
 import scala.language.higherKinds

@@ -49,7 +51,9 @@ trait DataGenerator[T] extends BaseGenerator[T] with Serializable {
         .expr.sql
       case _ => baseSqlExpression
     }
-    replaceLambdaFunction(expression)
+    val replaceLambda = replaceLambdaFunction(expression)
+    val replaceSubScalar = replaceSubScalarFunction(replaceLambda, baseSqlExpression)
+    replaceSubScalar
   }
 
   def generateWrapper(count: Int = 0): T = {
@@ -75,16 +79,33 @@ trait DataGenerator[T] extends BaseGenerator[T] with Serializable {
     }
   }
 
-  @tailrec
   private def replaceLambdaFunction(sql: String): String = {
     val lambdaRegex = ".*lambdafunction\\((.+?), i\\).*".r.pattern
-    val matcher = lambdaRegex.matcher(sql)
+    val replaceTargetFn: String => String = r => s"lambdafunction($r, i)"
+    val replacementFn: String => String = r => s"i -> $r"
+    replaceByRegex(sql, lambdaRegex, replaceTargetFn, replacementFn)
+  }
+
+  private def replaceSubScalarFunction(sql: String, originalSql: String): String = {
+    val lambdaRegex = ".*scalarsubquery\\((.*?)\\).*".r.pattern
+    val replaceTargetFn: String => String = r => s"scalarsubquery()"
+    val originalRegex = s".*\\(SELECT CAST\\((.+?) $DATA_CATERER_RANDOM_LENGTH\\).*".r.pattern
+    val matcher = originalRegex.matcher(originalSql)
     if (matcher.matches()) {
-      val innerFunction = matcher.group(1)
-      val replace = sql.replace(s"lambdafunction($innerFunction, i)", s"i -> $innerFunction")
-      replaceLambdaFunction(replace)
+      val replacementFn: String => String = _ => s"(SELECT CAST(${matcher.group(1)} $DATA_CATERER_RANDOM_LENGTH)"
+      replaceByRegex(sql, lambdaRegex, replaceTargetFn, replacementFn)
     } else sql
   }
+
+  @tailrec
+  private def replaceByRegex(text: String, pattern: Pattern, replaceTargetFn: String => String, replacementFn: String => String): String = {
+    val matcher = pattern.matcher(text)
+    if (matcher.matches()) {
+      val innerFunction = matcher.group(1)
+      val replace = text.replace(replaceTargetFn(innerFunction), replacementFn(innerFunction))
+      replaceByRegex(replace, pattern, replaceTargetFn, replacementFn)
+    } else text
+  }
 }
 
 trait NullableDataGenerator[T >: Null] extends DataGenerator[T] {
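The refactor here generalises the previous hard-coded lambda rewrite into `replaceByRegex`, a tail-recursive helper that keeps replacing whatever the pattern captures until nothing matches; `replaceLambdaFunction` and the new `replaceSubScalarFunction` are now thin wrappers supplying the target and replacement builders. A standalone sketch of that pattern (the input expression is made up for illustration):

```scala
import java.util.regex.Pattern
import scala.annotation.tailrec

object RegexRewriteSketch extends App {
  // Repeatedly swap the captured wrapper for its replacement until no match remains
  @tailrec
  def replaceByRegex(text: String, pattern: Pattern,
                     replaceTargetFn: String => String,
                     replacementFn: String => String): String = {
    val matcher = pattern.matcher(text)
    if (matcher.matches()) {
      val inner = matcher.group(1)
      replaceByRegex(text.replace(replaceTargetFn(inner), replacementFn(inner)), pattern, replaceTargetFn, replacementFn)
    } else text
  }

  val lambdaPattern = ".*lambdafunction\\((.+?), i\\).*".r.pattern
  val rewritten = replaceByRegex(
    "transform(arr, lambdafunction(i + 1, i))", // made-up generated SQL fragment
    lambdaPattern,
    inner => s"lambdafunction($inner, i)",
    inner => s"i -> $inner"
  )
  println(rewritten) // transform(arr, i -> i + 1)
}
```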

@@ -33,8 +33,8 @@ object GeneratorUtil {
       case x => throw new UnsupportedDataGeneratorType(x)
     }
   } else {
-    LOGGER.debug(s"No generator defined, will default to random generator, field-name=${structField.name}")
-    RandomDataGenerator.getGeneratorForStructField(structField, faker)
+    LOGGER.debug(s"No generator defined, will get type of generator based on field options, field-name=${structField.name}")
+    getDataGenerator(structField, faker)
   }
 }
 
(The remaining two of the nine changed files did not render in this capture.)