diff --git a/butterfree/configs/db/cassandra_config.py b/butterfree/configs/db/cassandra_config.py index d576359c..6d7f9a20 100644 --- a/butterfree/configs/db/cassandra_config.py +++ b/butterfree/configs/db/cassandra_config.py @@ -238,9 +238,9 @@ def translate(self, schema: List[Dict[str, Any]]) -> List[Dict[str, Any]]: "integertype": "int", "longtype": "bigint", "stringtype": "text", - "arraytype(longtype,true)": "frozen>", - "arraytype(stringtype,true)": "frozen>", - "arraytype(floattype,true)": "frozen>", + "arraytype(longtype, true)": "frozen>", + "arraytype(stringtype, true)": "frozen>", + "arraytype(floattype, true)": "frozen>", } cassandra_schema = [] for features in schema: diff --git a/tests/unit/butterfree/migrations/database_migration/conftest.py b/tests/unit/butterfree/migrations/database_migration/conftest.py index 237158b7..fed8528a 100644 --- a/tests/unit/butterfree/migrations/database_migration/conftest.py +++ b/tests/unit/butterfree/migrations/database_migration/conftest.py @@ -1,4 +1,4 @@ -from pyspark.sql.types import DoubleType, FloatType, LongType, TimestampType +from pyspark.sql.types import DoubleType, FloatType, LongType, TimestampType, ArrayType, StringType from pytest import fixture from butterfree.constants import DataType @@ -30,6 +30,7 @@ def fs_schema(): {"column_name": "id", "type": LongType(), "primary_key": True}, {"column_name": "timestamp", "type": TimestampType(), "primary_key": True}, {"column_name": "new_feature", "type": FloatType(), "primary_key": False}, + {"column_name": "array_feature", "type": ArrayType(StringType(),True), "primary_key": False}, { "column_name": "feature1__avg_over_1_week_rolling_windows", "type": FloatType(), diff --git a/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py b/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py index 5666cc47..5e89b65b 100644 --- a/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py +++ b/tests/unit/butterfree/migrations/database_migration/test_cassandra_migration.py @@ -33,9 +33,11 @@ def test_create_table_query(self, fs_schema): expected_query = [ "CREATE TABLE test.table_name " "(id LongType, timestamp TimestampType, new_feature FloatType, " + "array_feature ArrayType(StringType(), True), " "feature1__avg_over_1_week_rolling_windows FloatType, " "PRIMARY KEY (id, timestamp));" ] + query = cassandra_migration.create_query(fs_schema, "table_name") assert query, expected_query