Patch summary (text before the first file header is not part of any hunk):

* sqlframe/base/dataframe.py: the append path of withColumn now returns
  self.select.__wrapped__(self, ...) instead of self.copy().select(...).
  NOTE(review): presumably this calls the undecorated select; confirm against
  the operation decorator.
* sqlframe/base/normalize.py: removes a commented-out normalize_identifiers call.
* tests/unit/standalone/test_dataframe.py: adds test_with_column_dual_expression,
  which chains two withColumn calls and asserts the column list and the SQL
  string. The test carries a link to sqlframe issue 19.

NOTE(review): indentation inside the hunks was reconstructed from a
whitespace-collapsed copy of this patch; verify against the target files
before applying.

diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py
index 21f541c..2ee1714 100644
--- a/sqlframe/base/dataframe.py
+++ b/sqlframe/base/dataframe.py
@@ -1102,7 +1102,7 @@ def withColumn(self, colName: str, col: Column) -> Self:
             expression = self.expression.copy()
             expression.expressions[existing_col_index] = col.alias(col_name).expression
             return self.copy(expression=expression)
-        return self.copy().select(col.alias(col_name), append=True)
+        return self.select.__wrapped__(self, col.alias(col_name), append=True)  # type: ignore
 
     @operation(Operation.SELECT)
     def withColumnRenamed(self, existing: str, new: str) -> Self:
diff --git a/sqlframe/base/normalize.py b/sqlframe/base/normalize.py
index 8c8f6da..780343c 100644
--- a/sqlframe/base/normalize.py
+++ b/sqlframe/base/normalize.py
@@ -20,7 +20,6 @@ def normalize(session: SESSION, expression_context: exp.Select, expr: t.List[NOR
     expr = ensure_list(expr)
     expressions = _ensure_expressions(expr)
     for expression in expressions:
-        # normalize_identifiers(expression, session.input_dialect)
         identifiers = expression.find_all(exp.Identifier)
         for identifier in identifiers:
             identifier.transform(session.input_dialect.normalize_identifier)
diff --git a/tests/unit/standalone/test_dataframe.py b/tests/unit/standalone/test_dataframe.py
index 844ecd2..0923c25 100644
--- a/tests/unit/standalone/test_dataframe.py
+++ b/tests/unit/standalone/test_dataframe.py
@@ -57,6 +57,25 @@ def test_with_column_duplicate_alias(standalone_employee: StandaloneDataFrame):
     )
 
 
+# https://github.com/eakmanrq/sqlframe/issues/19
+def test_with_column_dual_expression(standalone_employee: StandaloneDataFrame):
+    df1 = standalone_employee.withColumn("new_col1", standalone_employee.age)
+    df2 = df1.withColumn("new_col2", standalone_employee.store_id)
+    assert df2.columns == [
+        "employee_id",
+        "fname",
+        "lname",
+        "age",
+        "store_id",
+        "new_col1",
+        "new_col2",
+    ]
+    assert (
+        df2.sql(pretty=False)
+        == "SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id`, `a1`.`age` AS `new_col1`, `a1`.`store_id` AS `new_col2` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+    )
+
+
 def test_where_expr(standalone_employee: StandaloneDataFrame):
     df = standalone_employee.where("fname = 'Jack' AND age = 37")
     assert df.columns == ["employee_id", "fname", "lname", "age", "store_id"]