From 53695352f596e723575fb32bc435f83d9be9b694 Mon Sep 17 00:00:00 2001 From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com> Date: Mon, 27 May 2024 17:29:26 -0700 Subject: [PATCH] feat: add openai enrichment support (#35) --- docs/configuration.md | 44 ++++++++++++++++++++++++++++++++++++ setup.py | 4 ++++ sqlframe/base/dataframe.py | 46 ++++++++++++++++++++++++++++++++++---- sqlframe/base/util.py | 9 ++++++++ 4 files changed, 99 insertions(+), 4 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index c653428..c9526cb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -183,3 +183,47 @@ The dialect of the generated SQL will be based on the session's dialect. However # create session and `df` like normal df.sql(dialect="bigquery") ``` + +### OpenAI Enriched + +OpenAI's models can be used to enrich the generated SQL to make it more human-like. +This is useful when you want to generate SQL that is more readable for humans. +You must install the `openai` extra (`pip install "sqlframe[openai]"`) and have `OPENAI_API_KEY` set in your environment variables to use this feature. + +```python +# create session and `df` like normal +# The model to use defaults to `gpt-4o` but can be changed by passing a string to the `openai_model` parameter. 
+>>> df.sql(optimize=False, use_openai=True) +WITH natality_data AS ( + SELECT + year, + ever_born + FROM `bigquery-public-data`.`samples`.`natality` +), single_child_families AS ( + SELECT + year, + COUNT(*) AS num_single_child_families + FROM natality_data + WHERE ever_born = 1 + GROUP BY year +), lagged_families AS ( + SELECT + year, + num_single_child_families, + LAG(num_single_child_families, 1) OVER (ORDER BY year) AS last_year_num_single_child_families + FROM single_child_families +), percent_change_families AS ( + SELECT + year, + num_single_child_families, + ((num_single_child_families - last_year_num_single_child_families) / last_year_num_single_child_families) AS percent_change + FROM lagged_families + ORDER BY ABS(percent_change) DESC +) +SELECT + year, + FORMAT('%\'.0f', ROUND(CAST(num_single_child_families AS FLOAT64), 0)) AS `new families single child`, + FORMAT('%\'.2f', ROUND(CAST((percent_change * 100) AS FLOAT64), 2)) AS `percent change` +FROM percent_change_families +LIMIT 5 +``` diff --git a/setup.py b/setup.py index c5c184a..9d2ff36 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ "dev": [ "duckdb>=0.9,<0.11", "mypy>=1.10.0,<1.11", + "openai>=1.30,<1.31", "pandas>=2,<3", "pandas-stubs>=2,<3", "psycopg>=3.1,<4", @@ -55,6 +56,9 @@ "duckdb>=0.9,<0.11", "pandas>=2,<3", ], + "openai": [ + "openai>=1.30,<1.31", + ], "pandas": [ "pandas>=2,<3", ], diff --git a/sqlframe/base/dataframe.py b/sqlframe/base/dataframe.py index 4271f07..318b878 100644 --- a/sqlframe/base/dataframe.py +++ b/sqlframe/base/dataframe.py @@ -15,6 +15,8 @@ from sqlglot import Dialect from sqlglot import expressions as exp from sqlglot.helper import ensure_list, object_to_dict, seq_get +from sqlglot.optimizer.pushdown_projections import pushdown_projections +from sqlglot.optimizer.qualify import qualify from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlframe.base.decorators import normalize @@ -24,6 +26,7 @@ get_func_from_session, 
get_tables_from_expression_with_join, quote_preserving_alias_or_name, + verify_openai_installed, ) if sys.version_info >= (3, 11): @@ -473,6 +476,8 @@ def sql( dialect: DialectType = None, optimize: bool = True, pretty: bool = True, + use_openai: bool = False, + openai_model: str = "gpt-4o", as_list: bool = False, **kwargs, ) -> t.Union[str, t.List[str]]: @@ -492,6 +497,9 @@ def sql( select_expression = t.cast( exp.Select, self.session._optimize(select_expression, dialect=dialect) ) + elif use_openai: + qualify(select_expression, dialect=dialect, schema=self.session.catalog._schema) + pushdown_projections(select_expression, schema=self.session.catalog._schema) select_expression = df._replace_cte_names_with_hashes(select_expression) @@ -545,10 +553,40 @@ def sql( output_expressions.append(expression) - results = [ - expression.sql(dialect=dialect, pretty=pretty, **kwargs) - for expression in output_expressions - ] + results = [] + for expression in output_expressions: + sql = expression.sql(dialect=dialect, pretty=pretty, **kwargs) + if use_openai: + verify_openai_installed() + from openai import OpenAI + + client = OpenAI() + prompt = f""" + You are a backend tool that converts correct {dialect} SQL to simplified and more human readable version. + You respond without code block with rewritten {dialect} SQL. + You don't change any column names in the final select because the user expects those to remain the same. + You make unique CTE alias names match what a human would write and in snake case. + You improve formatting with spacing and line-breaks. + You remove redundant parenthesis and aliases. 
+ When remove extra quotes, make sure to keep quotes around words that could be reserved words + """ + chat_completed = client.chat.completions.create( + messages=[ + { + "role": "system", + "content": prompt, + }, + { + "role": "user", + "content": sql, + }, + ], + model=openai_model, + ) + assert chat_completed.choices[0].message.content is not None + sql = chat_completed.choices[0].message.content + results.append(sql) + if as_list: return results return ";\n".join(results) diff --git a/sqlframe/base/util.py b/sqlframe/base/util.py index 669f39b..b3223ce 100644 --- a/sqlframe/base/util.py +++ b/sqlframe/base/util.py @@ -256,6 +256,15 @@ def verify_pandas_installed(): ) +def verify_openai_installed(): + try: + import openai # noqa + except ImportError: + raise ImportError( + """OpenAI is required for this functionality. `pip install "sqlframe[openai]"` (also include your engine if needed) to install openai.""" + ) + + def quote_preserving_alias_or_name(col: t.Union[exp.Column, exp.Alias]) -> str: from sqlframe.base.session import _BaseSession