diff --git a/CHANGELOG.md b/CHANGELOG.md index 57878f7f4..fd2cf9b1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,11 +6,10 @@ ### Fixes - - Cast `table_owner` to string to avoid errors generating docs ([#158](https://github.com/fishtown-analytics/dbt-spark/pull/158), [#159](https://github.com/fishtown-analytics/dbt-spark/pull/159)) +- Explicitly cast column types when inserting seeds ([#139](https://github.com/fishtown-analytics/dbt-spark/pull/139), [#166](https://github.com/fishtown-analytics/dbt-spark/pull/166)) ### Under the hood - - Parse information returned by `list_relations_without_caching` macro to speed up catalog generation ([#93](https://github.com/fishtown-analytics/dbt-spark/issues/93), [#160](https://github.com/fishtown-analytics/dbt-spark/pull/160)) - More flexible host passing, https:// can be omitted ([#153](https://github.com/fishtown-analytics/dbt-spark/issues/153)) diff --git a/dbt/include/spark/macros/materializations/seed.sql b/dbt/include/spark/macros/materializations/seed.sql index c857f013b..795f49329 100644 --- a/dbt/include/spark/macros/materializations/seed.sql +++ b/dbt/include/spark/macros/materializations/seed.sql @@ -1,6 +1,7 @@ {% macro spark__load_csv_rows(model, agate_table) %} {% set batch_size = 1000 %} - + {% set column_override = model['config'].get('column_types', {}) %} + {% set statements = [] %} {% for chunk in agate_table.rows | batch(batch_size) %} @@ -13,12 +14,10 @@ {% set sql %} insert into {{ this.render() }} values {% for row in chunk -%} - ({%- for column in agate_table.columns -%} - {%- if 'ISODate' in (column.data_type | string) -%} - cast(%s as timestamp) - {%- else -%} - %s - {%- endif -%} + ({%- for col_name in agate_table.column_names -%} + {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} + {%- set type = column_override.get(col_name, inferred_type) -%} + cast(%s as {{type}}) {%- if not loop.last%},{%- endif %} {%- endfor -%}) {%- if not loop.last%},{%- endif %} diff --git 
a/test/custom/seed_column_types/data/payments.csv b/test/custom/seed_column_types/data/payments.csv new file mode 100644 index 000000000..3f49d788c --- /dev/null +++ b/test/custom/seed_column_types/data/payments.csv @@ -0,0 +1,11 @@ +ID,ORDERID,PAYMENTMETHOD,STATUS,AMOUNT,AMOUNT_USD,CREATED +1,1,credit_card,success,1000,10.00,2018-01-01 +2,2,credit_card,success,2000,20.00,2018-01-02 +3,3,coupon,success,100,1.00,2018-01-04 +4,4,coupon,success,2500,25.00,2018-01-05 +5,5,bank_transfer,fail,1700,17.00,2018-01-05 +6,5,bank_transfer,success,1700,17.00,2018-01-05 +7,6,credit_card,success,600,6.00,2018-01-07 +8,7,credit_card,success,1600,16.00,2018-01-09 +9,8,credit_card,success,2300,23.00,2018-01-11 +10,9,gift_card,success,2300,23.00,2018-01-12 diff --git a/test/custom/seed_column_types/test_seed_column_types.py b/test/custom/seed_column_types/test_seed_column_types.py new file mode 100644 index 000000000..e1fc32788 --- /dev/null +++ b/test/custom/seed_column_types/test_seed_column_types.py @@ -0,0 +1,35 @@ +from test.custom.base import DBTSparkIntegrationTest, use_profile +import dbt.exceptions + + +class TestSeedColumnTypeCast(DBTSparkIntegrationTest): + @property + def schema(self): + return "seed_column_types" + + @property + def models(self): + return "models" + + @property + def project_config(self): + return { + 'seeds': { + 'quote_columns': False, + }, + } + + # runs on Spark v2.0 + @use_profile("apache_spark") + def test_seed_column_types_apache_spark(self): + self.run_dbt(["seed"]) + + # runs on Spark v3.0 + @use_profile("databricks_cluster") + def test_seed_column_types_databricks_cluster(self): + self.run_dbt(["seed"]) + + # runs on Spark v3.0 + @use_profile("databricks_sql_endpoint") + def test_seed_column_types_databricks_sql_endpoint(self): + self.run_dbt(["seed"])