Skip to content

Commit

Permalink
Postgres: Add window function support (#4283)
Browse files Browse the repository at this point in the history
* Postgresql: Support window function

* Use snapshot repo

* Postgresql: Support window function

* Use long as migration version

* Add integration test

* Bump to sql-psi 0.4.5

* Fix int usage

* Fix dialect test

* Fix dialect tests

* Fix schema version

* Fix schema version

* Fix tests

* Update settings.gradle

* Remove snapshot repo

* Remove snapshot repo

---------

Co-authored-by: hfhbd <hfhbd@users.noreply.github.com>
Co-authored-by: Alec Kazakova <AlecStrong@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 27, 2023
1 parent d597342 commit 50c2ae2
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.BIG_INT
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.SMALL_INT
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.TIMESTAMP
import app.cash.sqldelight.dialects.postgresql.PostgreSqlType.TIMESTAMP_TIMEZONE
import app.cash.sqldelight.dialects.postgresql.grammar.mixins.WindowFunctionMixin
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlDeleteStmtLimited
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlExtensionExpr
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlInsertStmt
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypeName
import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlUpdateStmtLimited
Expand Down Expand Up @@ -90,7 +92,19 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
"min" -> encapsulatingType(exprList, BLOB, TEXT, SMALL_INT, INTEGER, PostgreSqlType.INTEGER, BIG_INT, REAL, TIMESTAMP_TIMEZONE, TIMESTAMP).asNullable()
"date_trunc" -> encapsulatingType(exprList, TIMESTAMP_TIMEZONE, TIMESTAMP)
"date_part" -> IntermediateType(REAL)
"percentile_disc" -> IntermediateType(REAL).asNullable()
"now" -> IntermediateType(TIMESTAMP_TIMEZONE)
"corr", "covar_pop", "covar_samp", "regr_avgx", "regr_avgy", "regr_intercept",
"regr_r2", "regr_slope", "regr_sxx", "regr_sxy", "regr_syy",
-> IntermediateType(REAL).asNullable()
"stddev", "stddev_pop", "stddev_samp", "variance",
"var_pop", "var_samp",
-> if (resolvedType(exprList[0]).dialectType == REAL) {
IntermediateType(REAL).asNullable()
} else IntermediateType(
PostgreSqlType.NUMERIC,
).asNullable()
"regr_count" -> IntermediateType(BIG_INT).asNullable()
"gen_random_uuid" -> IntermediateType(PostgreSqlType.UUID)
"length", "character_length", "char_length" -> IntermediateType(PostgreSqlType.INTEGER).nullableIf(resolvedType(exprList[0]).javaType.isNullable)
else -> null
Expand Down Expand Up @@ -141,6 +155,13 @@ class PostgreSqlTypeResolver(private val parentResolver: TypeResolver) : TypeRes
literalValue.text.startsWith("INTERVAL") -> IntermediateType(PostgreSqlType.INTERVAL)
else -> parentResolver.resolvedType(this)
}
is PostgreSqlExtensionExpr -> when {
windowFunctionExpr != null -> {
val windowFunctionExpr = windowFunctionExpr as WindowFunctionMixin
functionType(windowFunctionExpr.functionExpr)!!
}
else -> parentResolver.resolvedType(this)
}

else -> parentResolver.resolvedType(this)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"static com.alecstrong.sql.psi.core.psi.SqlTypes.FOREIGN"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.FROM"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.GENERATED"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.GROUP"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.ID"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.IGNORE"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INSERT"
Expand Down Expand Up @@ -310,12 +311,16 @@ compound_select_stmt ::= [ {with_clause} ] {select_stmt} ( {compound_operator}
override = true
}

extension_expr ::= json_expression | boolean_literal | boolean_not_expression {
extension_expr ::= json_expression | boolean_literal | boolean_not_expression | window_function_expr {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlExtensionExprImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlExtensionExpr"
override = true
}

window_function_expr ::= {function_expr} 'WITHIN' GROUP LP ORDER BY <<expr '-1'>> ( COMMA <<expr '-1'>> ) * RP {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.WindowFunctionMixin"
}

boolean_not_expression ::= NOT (boolean_literal | {column_name})

boolean_literal ::= TRUE | FALSE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlWindowFunctionExpr
import com.alecstrong.sql.psi.core.psi.SqlCompositeElementImpl
import com.alecstrong.sql.psi.core.psi.SqlFunctionExpr
import com.intellij.lang.ASTNode

internal abstract class WindowFunctionMixin(
node: ASTNode,
) : SqlCompositeElementImpl(node),
PostgreSqlWindowFunctionExpr {
val functionExpr get() = children.filterIsInstance<SqlFunctionExpr>().single()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
CREATE TABLE myTable(
myColumn REAL NOT NULL
);

SELECT percentile_disc(.5) WITHIN GROUP (ORDER BY myTable.myColumn) AS P5
FROM myTable;
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE myTable(
foo REAL NOT NULL,
bar NUMERIC NOT NULL
);

SELECT
corr(foo),
stddev(bar),
stddev(foo),
regr_count(foo)
FROM myTable GROUP BY foo, bar;
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.intellij.psi.PsiElement
interface TypeResolver {
/**
* @param expr The expression to be resolved to a type.
* @return The type for [expr] for null if this resolver cannot solve.
* @return The resolved type
*/
fun resolvedType(expr: SqlExpr): IntermediateType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ const val SQLDELIGHT_EXTENSION = "sq"
object SqlDelightFileType : LanguageFileType(SqlDelightLanguage) {
private val ICON = AllIcons.Providers.Sqlite

const val FOLDER_NAME = "sqldelight"

override fun getName() = "SqlDelight"
override fun getDescription() = "SqlDelight"
override fun getDefaultExtension() = SQLDELIGHT_EXTENSION
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
CREATE TABLE myTable(
foo REAL NOT NULL,
bar NUMERIC NOT NULL
);

INSERT INTO myTable VALUES (1, 1), (2, 2), (3, 3);

selectPercentile:
SELECT percentile_disc(.5) WITHIN GROUP (ORDER BY foo) AS P5
FROM myTable;

selectStats:
SELECT
corr(foo, bar),
stddev(foo),
regr_count(foo, bar)
FROM myTable
GROUP BY foo, bar;
Original file line number Diff line number Diff line change
Expand Up @@ -312,4 +312,18 @@ class PostgreSqlTest {
val desc = database.charactersQueries.selectDescriptionLength().executeAsOne()
assertThat(desc.length).isNull()
}

@Test fun statFunctions() {
val percentile: SelectPercentile = database.functionsQueries.selectPercentile().executeAsOne()
val result: Double? = 2.0
assertThat(percentile).isEqualTo(SelectPercentile(result))
val stats: List<SelectStats> = database.functionsQueries.selectStats().executeAsList()
assertThat(stats).isEqualTo(
listOf(
SelectStats(null, null, 1),
SelectStats(null, null, 1),
SelectStats(null, null, 1),
),
)
}
}

0 comments on commit 50c2ae2

Please sign in to comment.