From c4aa982273e737a14fb76474cbd5dce800490062 Mon Sep 17 00:00:00 2001
From: David Wang
Date: Wed, 7 Feb 2024 09:17:47 -0800
Subject: [PATCH] SNOW-1011772: Suspend and resume commands for compute pools.
 (#740)

* Creating validation functions for object names.
* SNOW-1011772: Adding templates for suspend/resume
* SNOW-1011772: Adding 'resume' and 'suspend' commands for 'spcs compute-pool'.
* SNOW-1011772: Adding doc strings
* SNOW-1011772: Documentation fixes
* SNOW-1011772: Test fixes
* SNOW-1011772: Update release notes
* SNOW-1011772: Formatting
* SNOW-1011772: Updating compute pool callback to only allow a single valid identifier
---
 RELEASE-NOTES.md                              |   1 +
 src/snowflake/cli/api/project/util.py         |  17 ++-
 .../cli/plugins/spcs/compute_pool/commands.py |  44 +++++-
 .../cli/plugins/spcs/compute_pool/manager.py  |   8 +-
 tests/object/__snapshots__/test_object.ambr   |  97 +++++-------
 tests/object/test_common.py                   |   8 +-
 tests/object/test_object.py                   |  41 +++---
 tests/project/test_util.py                    |  28 +++-
 tests/spcs/test_compute_pool.py               |  97 ++++++++++++-
 tests_integration/spcs/test_cp.py             |  75 +++++-----
 .../spcs/test_image_repository.py             |   1 -
 .../assertions/test_result_assertions.py      |  21 +++
 .../testing_utils/compute_pool_utils.py       | 129 ++++++++++++++++++
 13 files changed, 430 insertions(+), 137 deletions(-)
 create mode 100644 tests_integration/testing_utils/compute_pool_utils.py

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 15663806ff..6b163dbfe5 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -7,6 +7,7 @@
 * Introduced `snowflake.cli.api.console.cli_console` object with helper methods for intermediate output.
 * Added new convenience command `spcs image-registry url` to get the URL for your account image registry.
 * Added convenience function `spcs image-repository url `.
+* Added `suspend` and `resume` commands for `spcs compute-pool`.

 ## Fixes and improvements
 * Restricted permissions of automatically created files

diff --git a/src/snowflake/cli/api/project/util.py b/src/snowflake/cli/api/project/util.py
index 71c2c98d9b..22196864d6 100644
--- a/src/snowflake/cli/api/project/util.py
+++ b/src/snowflake/cli/api/project/util.py
@@ -12,8 +12,9 @@
 SINGLE_QUOTED_STRING_LITERAL_REGEX = r"'((?:\\.|''|[^'\n])+?)'"

 # See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for identifier syntax
-UNQUOTED_IDENTIFIER_REGEX = r"(^[a-zA-Z_])([a-zA-Z0-9_$]{0,254})"
+UNQUOTED_IDENTIFIER_REGEX = r"([a-zA-Z_])([a-zA-Z0-9_$]{0,254})"
 QUOTED_IDENTIFIER_REGEX = r'"((""|[^"]){0,255})"'
+VALID_IDENTIFIER_REGEX = f"(?:{UNQUOTED_IDENTIFIER_REGEX}|{QUOTED_IDENTIFIER_REGEX})"


 def clean_identifier(input_: str):
@@ -47,6 +48,20 @@ def is_valid_identifier(identifier: str) -> bool:
     )


+def is_valid_object_name(name: str, max_depth=2) -> bool:
+    """
+    Determines whether the given identifier is a valid object name in the form <name>, <schema>.<name>, or <db>.<schema>.<name>.
+    max_depth determines how many additional qualifying identifiers are allowed after the first. For example, account level
+    objects would have a max depth of 0 because they cannot be qualified by a database or schema, just the single identifier.
+    """
+    if max_depth < 0:
+        raise ValueError("max_depth must be non-negative")
+    pattern = (
+        rf"{VALID_IDENTIFIER_REGEX}(?:\.{VALID_IDENTIFIER_REGEX}){{0,{max_depth}}}"
+    )
+    return re.fullmatch(pattern, name) is not None
+
+
 def to_identifier(name: str) -> str:
     """
     Converts a name to a valid Snowflake identifier. If the name is already a valid
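
For readers skimming the patch, a quick illustration of how the new helper above behaves (not part of the change itself; the object names below are arbitrary placeholders):

    from snowflake.cli.api.project.util import is_valid_object_name

    # the default max_depth=2 accepts up to two qualifiers, i.e. <db>.<schema>.<name>
    assert is_valid_object_name("my_db.my_schema.my_table")
    # account-level objects such as compute pools use max_depth=0,
    # so only a single, unqualified identifier is accepted
    assert is_valid_object_name("my_pool", max_depth=0)
    assert not is_valid_object_name("my_db.my_schema.my_pool", max_depth=0)
    # quoted identifiers are also valid
    assert is_valid_object_name('"my pool"', max_depth=0)
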
diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py
index ba4e5e0604..3c3d0c8a38 100644
--- a/src/snowflake/cli/plugins/spcs/compute_pool/commands.py
+++ b/src/snowflake/cli/plugins/spcs/compute_pool/commands.py
@@ -1,12 +1,14 @@
 from typing import Optional

 import typer
+from click import ClickException
 from snowflake.cli.api.commands.decorators import (
     global_options_with_connection,
     with_output,
 )
 from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS
 from snowflake.cli.api.output.types import CommandResult, SingleQueryResult
+from snowflake.cli.api.project.util import is_valid_object_name
 from snowflake.cli.plugins.object.common import comment_option
 from snowflake.cli.plugins.spcs.common import validate_and_set_instances
 from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager
@@ -18,11 +20,25 @@
 )


+def _compute_pool_name_callback(name: str) -> str:
+    """
+    Verifies that compute pool name is a single valid identifier.
+    """
+    if not is_valid_object_name(name, 0):
+        raise ClickException(f"{name} is not a valid compute pool name.")
+    return name
+
+
+ComputePoolNameArgument = typer.Argument(
+    ..., help="Name of the compute pool.", callback=_compute_pool_name_callback
+)
+
+
 @app.command()
 @with_output
 @global_options_with_connection
 def create(
-    name: str = typer.Argument(..., help="Name of the compute pool."),
+    name: str = ComputePoolNameArgument,
     min_nodes: int = typer.Option(
         1, "--min-nodes", help="Minimum number of nodes for the compute pool."
     ),
@@ -72,11 +88,29 @@ def create(
 @app.command("stop-all")
 @with_output
 @global_options_with_connection
-def stop_all(
-    name: str = typer.Argument(..., help="Name of the compute pool."), **options
-) -> CommandResult:
+def stop_all(name: str = ComputePoolNameArgument, **options) -> CommandResult:
     """
-    Stops a compute pool and deletes all services running on the pool.
+    Deletes all services running on the compute pool.
     """
     cursor = ComputePoolManager().stop(pool_name=name)
     return SingleQueryResult(cursor)
+
+
+@app.command()
+@with_output
+@global_options_with_connection
+def suspend(name: str = ComputePoolNameArgument, **options) -> CommandResult:
+    """
+    Suspends the compute pool by suspending all currently running services and then releasing compute pool nodes.
+    """
+    return SingleQueryResult(ComputePoolManager().suspend(name))
+
+
+@app.command()
+@with_output
+@global_options_with_connection
+def resume(name: str = ComputePoolNameArgument, **options) -> CommandResult:
+    """
+    Resumes the compute pool from SUSPENDED state.
+ """ + return SingleQueryResult(ComputePoolManager().resume(name)) diff --git a/src/snowflake/cli/plugins/spcs/compute_pool/manager.py b/src/snowflake/cli/plugins/spcs/compute_pool/manager.py index fd6845180c..897cd0ba47 100644 --- a/src/snowflake/cli/plugins/spcs/compute_pool/manager.py +++ b/src/snowflake/cli/plugins/spcs/compute_pool/manager.py @@ -31,4 +31,10 @@ def create( return self._execute_query(strip_empty_lines(query)) def stop(self, pool_name: str) -> SnowflakeCursor: - return self._execute_query(f"alter compute pool {pool_name} stop all;") + return self._execute_query(f"alter compute pool {pool_name} stop all") + + def suspend(self, pool_name: str) -> SnowflakeCursor: + return self._execute_query(f"alter compute pool {pool_name} suspend") + + def resume(self, pool_name: str) -> SnowflakeCursor: + return self._execute_query(f"alter compute pool {pool_name} resume") diff --git a/tests/object/__snapshots__/test_object.ambr b/tests/object/__snapshots__/test_object.ambr index a2cb01cf10..a7d36a4732 100644 --- a/tests/object/__snapshots__/test_object.ambr +++ b/tests/object/__snapshots__/test_object.ambr @@ -1,5 +1,5 @@ # serializer version: 1 -# name: test_describe[compute-pool-compute-pool-example] +# name: test_describe[compute-pool-compute_pool_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -11,7 +11,7 @@ ''' # --- -# name: test_describe[database-database-example] +# name: test_describe[database-database_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -23,7 +23,7 @@ ''' # --- -# name: test_describe[function-function-example] +# name: test_describe[function-function_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -35,7 +35,7 @@ ''' # --- -# name: test_describe[integration-integration] +# name: test_describe[integration-integration_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -47,7 +47,7 @@ ''' # --- -# name: test_describe[network-rule-network rule] +# name: test_describe[network-rule-network_rule_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -59,7 +59,7 @@ ''' # --- -# name: test_describe[network-rule-network-rule-example] +# name: test_describe[procedure-procedure_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -71,7 +71,7 @@ ''' # --- -# name: test_describe[procedure-procedure-example] +# name: test_describe[role-role_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -83,7 +83,7 @@ ''' # --- -# name: test_describe[role-role-example] +# name: test_describe[schema-schema_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -95,7 +95,7 @@ ''' # --- -# name: test_describe[schema-schema-example] +# name: test_describe[secret-secret_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -107,7 +107,7 @@ ''' # --- -# name: test_describe[secret-secret-example] +# name: test_describe[service-service_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -119,7 +119,7 @@ ''' # --- -# name: test_describe[service-service-example] +# name: test_describe[stage-stage_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -131,7 +131,7 @@ ''' # --- -# name: test_describe[stage-stage-example] +# name: test_describe[stream-stream_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -143,7 +143,7 @@ ''' # --- -# name: test_describe[stream-stream-example] +# name: test_describe[streamlit-streamlit_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -155,7 +155,7 @@ ''' # 
--- -# name: test_describe[streamlit-streamlit-example] +# name: test_describe[table-table_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -167,7 +167,7 @@ ''' # --- -# name: test_describe[table-table-example] +# name: test_describe[task-task_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -179,7 +179,7 @@ ''' # --- -# name: test_describe[task-task-example] +# name: test_describe[user-user_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -191,7 +191,7 @@ ''' # --- -# name: test_describe[user-user-example] +# name: test_describe[view-view_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -203,19 +203,7 @@ ''' # --- -# name: test_describe[view-view-example] - ''' - SELECT A MOCK QUERY - +-----------------------------+ - | name | type | kind | - |------+-------------+--------| - | ID | NUMBER(38,0 | COLUMN | - | NAME | VARCHAR(100 | COLUMN | - +-----------------------------+ - - ''' -# --- -# name: test_describe[warehouse-warehouse-example] +# name: test_describe[warehouse-warehouse_example] ''' SELECT A MOCK QUERY +-----------------------------+ @@ -235,7 +223,7 @@ ''' # --- -# name: test_drop[compute-pool-compute-pool-example] +# name: test_drop[compute-pool-compute_pool_example] ''' SELECT A MOCK QUERY +--------+ @@ -246,7 +234,7 @@ ''' # --- -# name: test_drop[database-database-example] +# name: test_drop[database-database_example] ''' SELECT A MOCK QUERY +--------+ @@ -257,7 +245,7 @@ ''' # --- -# name: test_drop[function-function-example] +# name: test_drop[function-function_example] ''' SELECT A MOCK QUERY +--------+ @@ -268,7 +256,7 @@ ''' # --- -# name: test_drop[image-repository-image-repository-example] +# name: test_drop[image-repository-image_repository_example] ''' SELECT A MOCK QUERY +--------+ @@ -279,7 +267,7 @@ ''' # --- -# name: test_drop[integration-integration] +# name: test_drop[integration-integration_example] ''' SELECT A MOCK QUERY +--------+ @@ -290,18 +278,7 @@ ''' # --- -# name: test_drop[network-rule-network rule] - ''' - SELECT A MOCK QUERY - +--------+ - | status | - |--------| - | n | - +--------+ - - ''' -# --- -# name: test_drop[network-rule-network-rule-example] +# name: test_drop[network-rule-network_rule_example] ''' SELECT A MOCK QUERY +--------+ @@ -312,7 +289,7 @@ ''' # --- -# name: test_drop[procedure-procedure-example] +# name: test_drop[procedure-procedure_example] ''' SELECT A MOCK QUERY +--------+ @@ -323,7 +300,7 @@ ''' # --- -# name: test_drop[role-role-example] +# name: test_drop[role-role_example] ''' SELECT A MOCK QUERY +--------+ @@ -334,7 +311,7 @@ ''' # --- -# name: test_drop[schema-schema-example] +# name: test_drop[schema-schema_example] ''' SELECT A MOCK QUERY +--------+ @@ -345,7 +322,7 @@ ''' # --- -# name: test_drop[secret-secret-example] +# name: test_drop[secret-secret_example] ''' SELECT A MOCK QUERY +--------+ @@ -356,7 +333,7 @@ ''' # --- -# name: test_drop[service-service-example] +# name: test_drop[service-service_example] ''' SELECT A MOCK QUERY +--------+ @@ -367,7 +344,7 @@ ''' # --- -# name: test_drop[stage-stage-example] +# name: test_drop[stage-stage_example] ''' SELECT A MOCK QUERY +--------+ @@ -378,7 +355,7 @@ ''' # --- -# name: test_drop[stream-stream-example] +# name: test_drop[stream-stream_example] ''' SELECT A MOCK QUERY +--------+ @@ -389,7 +366,7 @@ ''' # --- -# name: test_drop[streamlit-streamlit-example] +# name: test_drop[streamlit-streamlit_example] ''' SELECT A MOCK QUERY +--------+ @@ -400,7 +377,7 @@ ''' # --- -# name: 
test_drop[table-table-example] +# name: test_drop[table-table_example] ''' SELECT A MOCK QUERY +--------+ @@ -411,7 +388,7 @@ ''' # --- -# name: test_drop[task-task-example] +# name: test_drop[task-task_example] ''' SELECT A MOCK QUERY +--------+ @@ -422,7 +399,7 @@ ''' # --- -# name: test_drop[user-user-example] +# name: test_drop[user-user_example] ''' SELECT A MOCK QUERY +--------+ @@ -433,7 +410,7 @@ ''' # --- -# name: test_drop[view-view-example] +# name: test_drop[view-view_example] ''' SELECT A MOCK QUERY +--------+ @@ -444,7 +421,7 @@ ''' # --- -# name: test_drop[warehouse-warehouse-example] +# name: test_drop[warehouse-warehouse_example] ''' SELECT A MOCK QUERY +--------+ diff --git a/tests/object/test_common.py b/tests/object/test_common.py index 1495ea2c31..41c487a6ec 100644 --- a/tests/object/test_common.py +++ b/tests/object/test_common.py @@ -1,9 +1,11 @@ -from snowflake.cli.plugins.object.common import _parse_tag, Tag, TagError +from snowflake.cli.plugins.object.common import ( + _parse_tag, + Tag, + TagError, +) from typing import Tuple import pytest -from click import ClickException - @pytest.mark.parametrize( "value, expected", diff --git a/tests/object/test_object.py b/tests/object/test_object.py index 12f27a7fb0..d4bfcbb8a0 100644 --- a/tests/object/test_object.py +++ b/tests/object/test_object.py @@ -43,26 +43,25 @@ def test_show( DESCRIBE_TEST_OBJECTS = [ - ("compute-pool", "compute-pool-example"), - ("network-rule", "network-rule-example"), - ("integration", "integration"), - ("network-rule", "network rule"), - ("database", "database-example"), - ("function", "function-example"), - # ("job", "job-example"), - ("procedure", "procedure-example"), - ("role", "role-example"), - ("schema", "schema-example"), - ("service", "service-example"), - ("secret", "secret-example"), - ("stage", "stage-example"), - ("stream", "stream-example"), - ("streamlit", "streamlit-example"), - ("table", "table-example"), - ("task", "task-example"), - ("user", "user-example"), - ("warehouse", "warehouse-example"), - ("view", "view-example"), + ("compute-pool", "compute_pool_example"), + ("network-rule", "network_rule_example"), + ("integration", "integration_example"), + ("database", "database_example"), + ("function", "function_example"), + # ("job", "job_example"), + ("procedure", "procedure_example"), + ("role", "role_example"), + ("schema", "schema_example"), + ("service", "service_example"), + ("secret", "secret_example"), + ("stage", "stage_example"), + ("stream", "stream_example"), + ("streamlit", "streamlit_example"), + ("table", "table_example"), + ("task", "task_example"), + ("user", "user_example"), + ("warehouse", "warehouse_example"), + ("view", "view_example"), ] @@ -182,7 +181,7 @@ def test_describe_fails_image_repository(mock_cursor, runner, snapshot): DROP_TEST_OBJECTS = [ *DESCRIBE_TEST_OBJECTS, - ("image-repository", "image-repository-example"), + ("image-repository", "image_repository_example"), ] diff --git a/tests/project/test_util.py b/tests/project/test_util.py index 7e8cc30aa8..f5025e8492 100644 --- a/tests/project/test_util.py +++ b/tests/project/test_util.py @@ -7,9 +7,12 @@ is_valid_unquoted_identifier, to_identifier, to_string_literal, + is_valid_object_name, escape_like_pattern, ) +from itertools import permutations + VALID_UNQUOTED_IDENTIFIERS = ( "_", "____", @@ -45,7 +48,7 @@ INVALID_QUOTED_IDENTIFIERS = ( '"abc', # unterminated quote 'abc"', # missing leading quote - '"abc"def"', # improprely escaped inner quote + '"abc"def"', # improperly escaped inner quote 
) @@ -91,6 +94,29 @@ def test_is_valid_identifier(): assert not is_valid_identifier(id) +def test_is_valid_object_name(): + valid_identifiers = VALID_QUOTED_IDENTIFIERS + VALID_UNQUOTED_IDENTIFIERS + invalid_identifiers = INVALID_QUOTED_IDENTIFIERS + INVALID_QUOTED_IDENTIFIERS + + # any combination of 1, 2, or 3 valid identifiers separated by a '.' is valid + for num in [1, 2, 3]: + names = [".".join(p) for p in list(permutations(valid_identifiers, num))] + for name in names: + assert is_valid_object_name(name) + if num > 1: + assert not is_valid_object_name(name, 0) + + # any combination with at least one invalid identifier is invalid + for num in [1, 2, 3]: + valid_permutations = list(permutations(valid_identifiers, num - 1)) + for invalid_identifier in invalid_identifiers: + for valid_perm in valid_permutations: + combined_set = (invalid_identifier, *valid_perm) + names = [".".join(p) for p in list(permutations(combined_set, num))] + for name in names: + assert not is_valid_object_name(name) + + def test_to_identifier(): for id in VALID_UNQUOTED_IDENTIFIERS: assert to_identifier(id) == id diff --git a/tests/spcs/test_compute_pool.py b/tests/spcs/test_compute_pool.py index d6ce8efd0d..80c06aba97 100644 --- a/tests/spcs/test_compute_pool.py +++ b/tests/spcs/test_compute_pool.py @@ -1,8 +1,13 @@ from unittest.mock import Mock, patch from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager +from snowflake.cli.plugins.spcs.compute_pool.commands import _compute_pool_name_callback from snowflake.connector.cursor import SnowflakeCursor from snowflake.cli.api.project.util import to_string_literal +import json +import pytest + +from click import ClickException @patch( @@ -114,6 +119,96 @@ def test_stop(mock_execute_query): cursor = Mock(spec=SnowflakeCursor) mock_execute_query.return_value = cursor result = ComputePoolManager().stop(pool_name) - expected_query = "alter compute pool test_pool stop all;" + expected_query = "alter compute pool test_pool stop all" + mock_execute_query.assert_called_once_with(expected_query) + assert result == cursor + + +@patch( + "snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" +) +def test_suspend(mock_execute_query): + pool_name = "test_pool" + cursor = Mock(spec=SnowflakeCursor) + mock_execute_query.return_value = cursor + result = ComputePoolManager().suspend(pool_name) + expected_query = "alter compute pool test_pool suspend" mock_execute_query.assert_called_once_with(expected_query) assert result == cursor + + +@patch("snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager.suspend") +def test_suspend_cli(mock_suspend, mock_cursor, runner): + pool_name = "test_pool" + cursor = mock_cursor( + rows=[["Statement executed successfully."]], columns=["status"] + ) + mock_suspend.return_value = cursor + result = runner.invoke(["spcs", "compute-pool", "suspend", pool_name]) + mock_suspend.assert_called_once_with(pool_name) + assert result.exit_code == 0, result.output + assert "Statement executed successfully" in result.output + + cursor_copy = mock_cursor( + rows=[["Statement executed successfully."]], columns=["status"] + ) + mock_suspend.return_value = cursor_copy + result_json = runner.invoke( + ["spcs", "compute-pool", "suspend", pool_name, "--format", "json"] + ) + result_json_parsed = json.loads(result_json.output) + assert isinstance(result_json_parsed, dict) + assert result_json_parsed == {"status": "Statement executed successfully."} + + +@patch( + 
"snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager._execute_query" +) +def test_resume(mock_execute_query): + pool_name = "test_pool" + cursor = Mock(spec=SnowflakeCursor) + mock_execute_query.return_value = cursor + result = ComputePoolManager().resume(pool_name) + expected_query = "alter compute pool test_pool resume" + mock_execute_query.assert_called_once_with(expected_query) + assert result == cursor + + +@patch("snowflake.cli.plugins.spcs.compute_pool.manager.ComputePoolManager.resume") +def test_resume_cli(mock_resume, mock_cursor, runner): + pool_name = "test_pool" + cursor = mock_cursor( + rows=[["Statement executed successfully."]], columns=["status"] + ) + mock_resume.return_value = cursor + result = runner.invoke(["spcs", "compute-pool", "resume", pool_name]) + mock_resume.assert_called_once_with(pool_name) + assert result.exit_code == 0, result.output + assert "Statement executed successfully" in result.output + + cursor_copy = mock_cursor( + rows=[["Statement executed successfully."]], columns=["status"] + ) + mock_resume.return_value = cursor_copy + result_json = runner.invoke( + ["spcs", "compute-pool", "resume", pool_name, "--format", "json"] + ) + result_json_parsed = json.loads(result_json.output) + assert isinstance(result_json_parsed, dict) + assert result_json_parsed == {"status": "Statement executed successfully."} + + +@patch("snowflake.cli.plugins.spcs.compute_pool.commands.is_valid_object_name") +def test_compute_pool_name_callback(mock_is_valid): + name = "test_pool" + mock_is_valid.return_value = True + assert _compute_pool_name_callback(name) == name + + +@patch("snowflake.cli.plugins.spcs.compute_pool.commands.is_valid_object_name") +def test_compute_pool_name_callback_invalid(mock_is_valid): + name = "test_pool" + mock_is_valid.return_value = False + with pytest.raises(ClickException) as e: + _compute_pool_name_callback(name) + assert "is not a valid compute pool name." in e.value.message diff --git a/tests_integration/spcs/test_cp.py b/tests_integration/spcs/test_cp.py index f727a7ac67..16c972536e 100644 --- a/tests_integration/spcs/test_cp.py +++ b/tests_integration/spcs/test_cp.py @@ -1,56 +1,45 @@ -import time +import uuid +from typing import Tuple import pytest -from tests_integration.test_utils import ( - contains_row_with, - not_contains_row_with, - row_from_snowflake_session, +from tests_integration.testing_utils.compute_pool_utils import ( + ComputePoolTestSetup, + ComputePoolTestSteps, ) @pytest.mark.integration -def test_cp(runner, snowflake_session): - cp_name = f"test_compute_pool_snowcli_{int(time.time())}" - - result = runner.invoke_with_connection_json( - [ - "spcs", - "compute-pool", - "create", - cp_name, - "--min-nodes", - 1, - "--family", - "STANDARD_1", - ] - ) - assert result.json, result.output - assert "status" in result.json - assert ( - f"Compute Pool {cp_name.upper()} successfully created." 
in result.json["status"] - ) +def test_cp(_test_steps: Tuple[ComputePoolTestSteps, str]): - expect = snowflake_session.execute_string(f"show compute pools like '{cp_name}'") - result = runner.invoke_with_connection_json(["object", "list", "compute-pool"]) + test_steps, compute_pool_name = _test_steps - assert result.json, result.output - assert contains_row_with(result.json, row_from_snowflake_session(expect)[0]) + test_steps.create_compute_pool(compute_pool_name) + test_steps.list_should_return_compute_pool(compute_pool_name) + test_steps.stop_all_on_compute_pool(compute_pool_name) + test_steps.suspend_compute_pool(compute_pool_name) + test_steps.wait_until_compute_pool_is_suspended(compute_pool_name) + test_steps.resume_compute_pool(compute_pool_name) + test_steps.wait_until_compute_pool_is_idle(compute_pool_name) + test_steps.drop_compute_pool(compute_pool_name) + test_steps.list_should_not_return_compute_pool(compute_pool_name) - result = runner.invoke_with_connection_json( - ["spcs", "compute-pool", "stop-all", cp_name] - ) - assert contains_row_with( - result.json, - {"status": "Statement executed successfully."}, - ) - result = runner.invoke_with_connection_json( - ["object", "drop", "compute-pool", cp_name] +@pytest.fixture +def _test_setup(runner, snowflake_session): + compute_pool_test_setup = ComputePoolTestSetup( + runner=runner, snowflake_session=snowflake_session ) - assert contains_row_with( - result.json, - {"status": f"{cp_name.upper()} successfully dropped."}, + yield compute_pool_test_setup + + +@pytest.fixture +def _test_steps(_test_setup): + compute_pool_name = f"compute_pool_{uuid.uuid4().hex}" + test_steps = ComputePoolTestSteps(_test_setup) + + yield test_steps, compute_pool_name + + _test_setup.snowflake_session.execute_string( + f"drop compute pool if exists {compute_pool_name}" ) - expect = snowflake_session.execute_string(f"show compute pools like '{cp_name}'") - assert not_contains_row_with(row_from_snowflake_session(expect), {"name": cp_name}) diff --git a/tests_integration/spcs/test_image_repository.py b/tests_integration/spcs/test_image_repository.py index 0d71f5a36c..580f8c8b36 100644 --- a/tests_integration/spcs/test_image_repository.py +++ b/tests_integration/spcs/test_image_repository.py @@ -29,7 +29,6 @@ def _list_images(runner): INTEGRATION_SCHEMA, ] ) - # breakpoint() assert isinstance(result.json, list), result.output assert contains_row_with( result.json, diff --git a/tests_integration/testing_utils/assertions/test_result_assertions.py b/tests_integration/testing_utils/assertions/test_result_assertions.py index c1a38fb42a..58a51d9dbd 100644 --- a/tests_integration/testing_utils/assertions/test_result_assertions.py +++ b/tests_integration/testing_utils/assertions/test_result_assertions.py @@ -40,3 +40,24 @@ def assert_that_result_is_successful_and_done_is_on_output( assert_that_result_is_successful(result) assert result.output is not None assert json.loads(result.output) == {"message": "Done"} + + +def assert_that_result_is_successful_and_executed_successfully( + result: CommandResult, is_json: bool = False +) -> None: + """ + Checks that the command result is {"status": "Statement executed successfully"} as either json or text output. 
+ """ + assert_that_result_is_successful(result) + if is_json: + success_message = {"status": "Statement executed successfully."} + assert result.json is not None + if isinstance(result.json, dict): + assert result.json == success_message + else: + assert len(result.json) == 1 + assert result.json[0] == success_message + else: + assert result.output is not None + assert "status" in result.output + assert "Statement executed successfully" in result.output diff --git a/tests_integration/testing_utils/compute_pool_utils.py b/tests_integration/testing_utils/compute_pool_utils.py new file mode 100644 index 0000000000..4f73f3dbaf --- /dev/null +++ b/tests_integration/testing_utils/compute_pool_utils.py @@ -0,0 +1,129 @@ +import math +import time + +import pytest +from snowflake.connector import SnowflakeConnection + +from tests_integration.conftest import SnowCLIRunner +from tests_integration.test_utils import contains_row_with, not_contains_row_with +from tests_integration.testing_utils.assertions.test_result_assertions import ( + assert_that_result_is_successful_and_executed_successfully, +) + + +class ComputePoolTestSetup: + def __init__( + self, + runner: SnowCLIRunner, + snowflake_session: SnowflakeConnection, + ): + self.runner = runner + self.snowflake_session = snowflake_session + + +class ComputePoolTestSteps: + def __init__(self, setup: ComputePoolTestSetup): + self._setup = setup + + def create_compute_pool(self, compute_pool_name: str) -> None: + result = self._setup.runner.invoke_with_connection_json( + [ + "spcs", + "compute-pool", + "create", + compute_pool_name, + "--min-nodes", + 1, + "--family", + "CPU_X64_XS", + ] + ) + assert result.json, result.output + assert "status" in result.json + assert ( + f"Compute Pool {compute_pool_name.upper()} successfully created." + in result.json["status"] # type: ignore + ) + + def list_should_return_compute_pool(self, compute_pool_name) -> None: + result = self._execute_list() + assert contains_row_with(result.json, {"name": compute_pool_name.upper()}) + + def list_should_not_return_compute_pool(self, compute_pool_name: str) -> None: + result = self._execute_list() + assert not_contains_row_with(result.json, {"name": compute_pool_name.upper()}) + + def describe_should_return_compute_pool(self, compute_pool_name: str) -> None: + result = self._execute_describe(compute_pool_name) + assert result.json + assert len(result.json) == 1 + assert result.json[0]["name"] == compute_pool_name.upper() + + def drop_compute_pool(self, compute_pool_name: str) -> None: + result = self._setup.runner.invoke_with_connection_json( + [ + "object", + "drop", + "compute-pool", + compute_pool_name, + ], + ) + assert result.json + assert len(result.json) == 1 + assert result.json[0] == { # type: ignore + "status": f"{compute_pool_name.upper()} successfully dropped." 
+ } + + def stop_all_on_compute_pool(self, compute_pool_name: str) -> None: + result = self._setup.runner.invoke_with_connection_json( + ["spcs", "compute-pool", "stop-all", compute_pool_name] + ) + assert_that_result_is_successful_and_executed_successfully(result, is_json=True) + + def suspend_compute_pool(self, compute_pool_name: str) -> None: + result = self._setup.runner.invoke_with_connection_json( + ["spcs", "compute-pool", "suspend", compute_pool_name] + ) + assert_that_result_is_successful_and_executed_successfully(result, is_json=True) + + def resume_compute_pool(self, compute_pool_name: str) -> None: + result = self._setup.runner.invoke_with_connection_json( + ["spcs", "compute-pool", "resume", compute_pool_name] + ) + assert_that_result_is_successful_and_executed_successfully(result, is_json=True) + + def wait_until_compute_pool_is_idle(self, compute_pool_name: str) -> None: + self._wait_until_compute_pool_reaches_state(compute_pool_name, "IDLE", 300) + + def wait_until_compute_pool_is_suspended(self, compute_pool_name: str) -> None: + self._wait_until_compute_pool_reaches_state(compute_pool_name, "SUSPENDED", 60) + + def _wait_until_compute_pool_reaches_state( + self, compute_pool_name: str, target_state: str, max_duration: int + ): + assert max_duration > 0 + max_counter = math.ceil(max_duration / 10) + target_state = target_state.upper() + for i in range(max_counter): + status = self._execute_describe(compute_pool_name) + if contains_row_with(status.json, {"state": target_state}): + return + time.sleep(10) + status = self._execute_describe(compute_pool_name) + + error_message = f"Compute pool {compute_pool_name} didn't reach target state '{target_state}' in {max_duration} seconds. Current state is '{status.json['state']}'" + pytest.fail(error_message) + + def _execute_describe(self, compute_pool_name: str): + return self._setup.runner.invoke_with_connection_json( + ["object", "describe", "compute-pool", compute_pool_name] + ) + + def _execute_list(self): + return self._setup.runner.invoke_with_connection_json( + [ + "object", + "list", + "compute-pool", + ], + )
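
For reference, each of the new `suspend` and `resume` commands issues a single ALTER COMPUTE POOL statement through `ComputePoolManager`. A minimal sketch of calling the manager directly (illustrative only; the pool name is a placeholder and a configured Snowflake connection is assumed):

    from snowflake.cli.plugins.spcs.compute_pool.manager import ComputePoolManager

    manager = ComputePoolManager()
    manager.suspend("my_pool")  # runs: alter compute pool my_pool suspend
    manager.resume("my_pool")   # runs: alter compute pool my_pool resume
    manager.stop("my_pool")     # runs: alter compute pool my_pool stop all

From the CLI these correspond to `snow spcs compute-pool suspend <name>` and `snow spcs compute-pool resume <name>`.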