Skip to content

Commit

Permalink
Use IDENTIFIER when building snowpark queries
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Jul 12, 2024
1 parent 8938645 commit c69dbcf
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 58 deletions.
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
* The `snow app version create` command now allows operating on application packages created outside the CLI
* Added support for user stages in stage execute command
* Added support for user stages in stage and git copy commands

* Improved support for quoted identifiers in snowpark commands.

# v2.6.0
## Backward incompatibility
Expand Down
34 changes: 15 additions & 19 deletions src/snowflake/cli/plugins/snowpark/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from snowflake.cli.plugins.object.manager import ObjectManager
from snowflake.cli.plugins.snowpark import package_utils
from snowflake.cli.plugins.snowpark.common import (
build_udf_sproc_identifier,
UdfSprocIdentifier,
check_if_replace_is_required,
)
from snowflake.cli.plugins.snowpark.manager import FunctionManager, ProcedureManager
Expand Down Expand Up @@ -239,14 +239,14 @@ def _assert_object_definitions_are_correct(

def _find_existing_objects(
object_type: ObjectType,
objects: List[Dict],
objects: List[FunctionSchema | ProcedureSchema],
om: ObjectManager,
):
existing_objects = {}
for object_definition in objects:
identifier = build_udf_sproc_identifier(
object_definition, om, include_parameter_names=False
)
identifier = UdfSprocIdentifier.from_definition(
object_definition
).identifier_with_arg_types
try:
current_state = om.describe(
object_type=object_type.value.sf_name,
Expand Down Expand Up @@ -300,28 +300,24 @@ def _deploy_single_object(
snowflake_dependencies: List[str],
stage_artifact_path: str,
):
identifier = build_udf_sproc_identifier(
object_definition, manager, include_parameter_names=False
)
identifier_with_default_values = build_udf_sproc_identifier(
object_definition,
manager,
include_parameter_names=True,
include_default_values=True,

identifiers = UdfSprocIdentifier.from_definition(object_definition)

log.info(
"Deploying %s: %s", object_type, identifiers.identifier_with_arg_names_types
)
log.info("Deploying %s: %s", object_type, identifier_with_default_values)

handler = object_definition.handler
returns = object_definition.returns
imports = object_definition.imports
external_access_integrations = object_definition.external_access_integrations
replace_object = False

object_exists = identifier in existing_objects
object_exists = identifiers.identifier_with_arg_types in existing_objects
if object_exists:
replace_object = check_if_replace_is_required(
object_type=object_type,
current_state=existing_objects[identifier],
current_state=existing_objects[identifiers.identifier_with_arg_types],
handler=handler,
return_type=returns,
snowflake_dependencies=snowflake_dependencies,
Expand All @@ -332,13 +328,13 @@ def _deploy_single_object(

if object_exists and not replace_object:
return {
"object": identifier_with_default_values,
"object": identifiers.identifier_with_arg_names_types_defaults,
"type": str(object_type),
"status": "packages updated",
}

create_or_replace_kwargs = {
"identifier": identifier_with_default_values,
"identifier": identifiers,
"handler": handler,
"return_type": returns,
"artifact_file": stage_artifact_path,
Expand All @@ -356,7 +352,7 @@ def _deploy_single_object(

status = "created" if not object_exists else "definition updated"
return {
"object": identifier_with_default_values,
"object": identifiers.identifier_with_arg_names_types_defaults,
"type": str(object_type),
"status": status,
}
Expand Down
94 changes: 68 additions & 26 deletions src/snowflake/cli/plugins/snowpark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

from snowflake.cli.api.constants import ObjectType
from snowflake.cli.api.identifiers import FQN
from snowflake.cli.api.project.schemas.snowpark.argument import Argument
from snowflake.cli.api.project.schemas.snowpark.callable import (
FunctionSchema,
ProcedureSchema,
)
from snowflake.cli.api.sql_execution import SqlExecutionMixin
from snowflake.cli.plugins.snowpark.models import Requirement
from snowflake.cli.plugins.snowpark.package_utils import (
Expand Down Expand Up @@ -150,7 +153,7 @@ def artifact_stage_path(identifier: str):

def create_query(
self,
identifier: str,
identifier: UdfSprocIdentifier,
return_type: str,
handler: str,
artifact_file: str,
Expand All @@ -166,7 +169,7 @@ def create_query(
packages_list = ",".join(f"'{p}'" for p in packages)

query = [
f"create or replace {self._object_type.value.sf_name} {identifier}",
f"create or replace {self._object_type.value.sf_name} {identifier.identifier_for_sql}",
f"copy grants",
f"returns {return_type}",
"language python",
Expand Down Expand Up @@ -198,30 +201,69 @@ def _is_signature_type_a_string(sig_type: str) -> bool:
return sig_type.lower() in ["string", "varchar"]


def build_udf_sproc_identifier(
udf_sproc,
slq_exec_mixin,
include_parameter_names,
include_default_values=False,
):
def format_arg(arg: Argument):
result = f"{arg.arg_type}"
if include_parameter_names:
result = f"{arg.name} {result}"
if include_default_values and arg.default:
val = f"{arg.default}"
if _is_signature_type_a_string(arg.arg_type):
val = f"'{val}'"
result += f" default {val}"
return result

if udf_sproc.signature and udf_sproc.signature != "null":
arguments = ", ".join(format_arg(arg) for arg in udf_sproc.signature)
else:
arguments = ""
class UdfSprocIdentifier:
def __init__(self, identifier: FQN, arg_names, arg_types, arg_defaults):
self._identifier = identifier
self._arg_names = arg_names
self._arg_types = arg_types
self._arg_defaults = arg_defaults

def _identifier_from_signature(self, sig: List[str], for_sql: bool = False):
signature = self._comma_join(sig)
id_ = self._identifier.sql_identifier if for_sql else self._identifier
return f"{id_}({signature})"

@staticmethod
def _comma_join(*args):
return ", ".join(*args)

@property
def identifier_with_arg_names(self):
return self._identifier_from_signature(self._arg_names)

name = FQN.from_identifier_model(udf_sproc).using_context().identifier
return f"{name}({arguments})"
@property
def identifier_with_arg_types(self):
return self._identifier_from_signature(self._arg_types)

@property
def identifier_with_arg_names_types(self):
sig = [f"{n} {t}" for n, t in zip(self._arg_names, self._arg_types)]
return self._identifier_from_signature(sig)

@property
def identifier_with_arg_names_types_defaults(self):
return self._identifier_from_signature(self._full_signature())

def _full_signature(self):
sig = []
for name, _type, _default in zip(
self._arg_names, self._arg_types, self._arg_defaults
):
s = f"{name} {_type}"
if _default:
if _is_signature_type_a_string(_type):
_default = f"'{_default}'"
s += f" default {_default}"
sig.append(s)
return sig

@property
def identifier_for_sql(self):
return self._identifier_from_signature(self._full_signature(), for_sql=True)

@classmethod
def from_definition(cls, udf_sproc: FunctionSchema | ProcedureSchema):
names = []
types = []
defaults = []
if udf_sproc.signature and udf_sproc.signature != "null":
for arg in udf_sproc.signature:
names.append(arg.name)
types.append(arg.arg_type)
defaults.append(arg.default)

identifier = FQN.from_identifier_model(udf_sproc).using_context()
return cls(identifier, names, types, defaults)


def _compare_imports(
Expand Down
21 changes: 16 additions & 5 deletions src/snowflake/cli/plugins/snowpark/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
from typing import Dict, List, Optional

from snowflake.cli.api.constants import ObjectType
from snowflake.cli.plugins.snowpark.common import SnowparkObjectManager
from snowflake.cli.plugins.snowpark.common import (
SnowparkObjectManager,
UdfSprocIdentifier,
)
from snowflake.connector.cursor import SnowflakeCursor

log = logging.getLogger(__name__)
Expand All @@ -35,7 +38,7 @@ def _object_execute(self):

def create_or_replace(
self,
identifier: str,
identifier: UdfSprocIdentifier,
return_type: str,
handler: str,
artifact_file: str,
Expand All @@ -45,7 +48,11 @@ def create_or_replace(
secrets: Optional[Dict[str, str]] = None,
runtime: Optional[str] = None,
) -> SnowflakeCursor:
log.debug("Creating function %s using @%s", identifier, artifact_file)
log.debug(
"Creating function %s using @%s",
identifier.identifier_with_arg_names_types_defaults,
artifact_file,
)
query = self.create_query(
identifier,
return_type,
Expand All @@ -71,7 +78,7 @@ def _object_execute(self):

def create_or_replace(
self,
identifier: str,
identifier: UdfSprocIdentifier,
return_type: str,
handler: str,
artifact_file: str,
Expand All @@ -82,7 +89,11 @@ def create_or_replace(
runtime: Optional[str] = None,
execute_as_caller: bool = False,
) -> SnowflakeCursor:
log.debug("Creating procedure %s using @%s", identifier, artifact_file)
log.debug(
"Creating procedure %s using @%s",
identifier.identifier_with_arg_names_types_defaults,
artifact_file,
)
query = self.create_query(
identifier,
return_type,
Expand Down
8 changes: 4 additions & 4 deletions tests/snowpark/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_deploy_function(
f" auto_compress=false parallel=4 overwrite=True",
dedent(
"""\
create or replace function MockDatabase.MockSchema.func1(a string default 'default value', b variant)
create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string default 'default value', b variant)
copy grants
returns string
language python
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_deploy_function_with_external_access(
f" auto_compress=false parallel=4 overwrite=True",
dedent(
"""\
create or replace function MockDatabase.MockSchema.func1(a string, b variant)
create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string, b variant)
copy grants
returns string
language python
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_deploy_function_needs_update_because_packages_changes(
f"put file://{Path(project_dir).resolve()}/app.zip @MockDatabase.MockSchema.dev_deployment/my_snowpark_project auto_compress=false parallel=4 overwrite=True",
dedent(
"""\
create or replace function MockDatabase.MockSchema.func1(a string default 'default value', b variant)
create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string default 'default value', b variant)
copy grants
returns string
language python
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_deploy_function_needs_update_because_handler_changes(
f" auto_compress=false parallel=4 overwrite=True",
dedent(
"""\
create or replace function MockDatabase.MockSchema.func1(a string default 'default value', b variant)
create or replace function IDENTIFIER('MockDatabase.MockSchema.func1')(a string default 'default value', b variant)
copy grants
returns string
language python
Expand Down
6 changes: 3 additions & 3 deletions tests/snowpark/test_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_deploy_procedure(
f"put file://{Path(tmp).resolve()}/app.zip @MockDatabase.MockSchema.dev_deployment/my_snowpark_project auto_compress=false parallel=4 overwrite=True",
dedent(
"""\
create or replace procedure MockDatabase.MockSchema.procedureName(name string)
create or replace procedure IDENTIFIER('MockDatabase.MockSchema.procedureName')(name string)
copy grants
returns string
language python
Expand All @@ -91,7 +91,7 @@ def test_deploy_procedure(
).strip(),
dedent(
"""\
create or replace procedure MockDatabase.MockSchema.test()
create or replace procedure IDENTIFIER('MockDatabase.MockSchema.test')()
copy grants
returns string
language python
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_deploy_procedure_with_external_access(
f" auto_compress=false parallel=4 overwrite=True",
dedent(
"""\
create or replace procedure MockDatabase.MockSchema.procedureName(name string)
create or replace procedure IDENTIFIER('MockDatabase.MockSchema.procedureName')(name string)
copy grants
returns string
language python
Expand Down

0 comments on commit c69dbcf

Please sign in to comment.