Skip to content

Commit

Permalink
feat: add openai enrichment support (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored May 28, 2024
1 parent d812c13 commit 5369535
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
44 changes: 44 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,47 @@ The dialect of the generated SQL will be based on the session's dialect. However
# create session and `df` like normal
df.sql(dialect="bigquery")
```

### OpenAI Enriched

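OpenAI's models can be used to rewrite the generated SQL so that it reads more like SQL a person would write by hand, for example with descriptive CTE names and cleaner formatting.
This is useful when the generated SQL will be read or maintained by humans.
To use this feature, install the optional dependency with `pip install "sqlframe[openai]"` and set the `OPENAI_API_KEY` environment variable.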

```python
# create session and `df` like normal
# The model to use defaults to `gpt-4o` but can be changed by passing a string to the `openai_model` parameter.
>>> df.sql(optimize=False, use_openai=True)
WITH natality_data AS (
SELECT
year,
ever_born
FROM `bigquery-public-data`.`samples`.`natality`
), single_child_families AS (
SELECT
year,
COUNT(*) AS num_single_child_families
FROM natality_data
WHERE ever_born = 1
GROUP BY year
), lagged_families AS (
SELECT
year,
num_single_child_families,
LAG(num_single_child_families, 1) OVER (ORDER BY year) AS last_year_num_single_child_families
FROM single_child_families
), percent_change_families AS (
SELECT
year,
num_single_child_families,
((num_single_child_families - last_year_num_single_child_families) / last_year_num_single_child_families) AS percent_change
FROM lagged_families
ORDER BY ABS(percent_change) DESC
)
SELECT
year,
FORMAT('%\'.0f', ROUND(CAST(num_single_child_families AS FLOAT64), 0)) AS `new families single child`,
FORMAT('%\'.2f', ROUND(CAST((percent_change * 100) AS FLOAT64), 2)) AS `percent change`
FROM percent_change_families
LIMIT 5
```
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"dev": [
"duckdb>=0.9,<0.11",
"mypy>=1.10.0,<1.11",
"openai>=1.30,<1.31",
"pandas>=2,<3",
"pandas-stubs>=2,<3",
"psycopg>=3.1,<4",
Expand All @@ -55,6 +56,9 @@
"duckdb>=0.9,<0.11",
"pandas>=2,<3",
],
"openai": [
"openai>=1.30,<1.31",
],
"pandas": [
"pandas>=2,<3",
],
Expand Down
46 changes: 42 additions & 4 deletions sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_list, object_to_dict, seq_get
from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify import qualify
from sqlglot.optimizer.qualify_columns import quote_identifiers

from sqlframe.base.decorators import normalize
Expand All @@ -24,6 +26,7 @@
get_func_from_session,
get_tables_from_expression_with_join,
quote_preserving_alias_or_name,
verify_openai_installed,
)

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -473,6 +476,8 @@ def sql(
dialect: DialectType = None,
optimize: bool = True,
pretty: bool = True,
use_openai: bool = False,
openai_model: str = "gpt-4o",
as_list: bool = False,
**kwargs,
) -> t.Union[str, t.List[str]]:
Expand All @@ -492,6 +497,9 @@ def sql(
select_expression = t.cast(
exp.Select, self.session._optimize(select_expression, dialect=dialect)
)
elif use_openai:
qualify(select_expression, dialect=dialect, schema=self.session.catalog._schema)
pushdown_projections(select_expression, schema=self.session.catalog._schema)

select_expression = df._replace_cte_names_with_hashes(select_expression)

Expand Down Expand Up @@ -545,10 +553,40 @@ def sql(

output_expressions.append(expression)

results = [
expression.sql(dialect=dialect, pretty=pretty, **kwargs)
for expression in output_expressions
]
results = []
for expression in output_expressions:
sql = expression.sql(dialect=dialect, pretty=pretty, **kwargs)
if use_openai:
verify_openai_installed()
from openai import OpenAI

client = OpenAI()
prompt = f"""
You are a backend tool that converts correct {dialect} SQL to simplified and more human readable version.
You respond without code block with rewritten {dialect} SQL.
You don't change any column names in the final select because the user expects those to remain the same.
You make unique CTE alias names match what a human would write and in snake case.
You improve formatting with spacing and line-breaks.
You remove redundant parenthesis and aliases.
When remove extra quotes, make sure to keep quotes around words that could be reserved words
"""
chat_completed = client.chat.completions.create(
messages=[
{
"role": "system",
"content": prompt,
},
{
"role": "user",
"content": sql,
},
],
model=openai_model,
)
assert chat_completed.choices[0].message.content is not None
sql = chat_completed.choices[0].message.content
results.append(sql)

if as_list:
return results
return ";\n".join(results)
Expand Down
9 changes: 9 additions & 0 deletions sqlframe/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,15 @@ def verify_pandas_installed():
)


def verify_openai_installed():
    """Ensure the optional ``openai`` package can be imported.

    Returns:
        None. The function only has an effect when the import fails.

    Raises:
        ImportError: If ``openai`` cannot be imported. The message explains how
            to install the optional extra, and the original import failure is
            attached as the explicit cause.
    """
    try:
        import openai  # noqa
    except ImportError as exc:
        # Chain explicitly so the traceback reports the underlying import
        # failure as the direct cause of this friendlier error.
        raise ImportError(
            """OpenAI is required for this functionality. `pip install "sqlframe[openai]"` (also include your engine if needed) to install openai."""
        ) from exc


def quote_preserving_alias_or_name(col: t.Union[exp.Column, exp.Alias]) -> str:
from sqlframe.base.session import _BaseSession

Expand Down

0 comments on commit 5369535

Please sign in to comment.