Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SnowTyper together with telemetry #731

Merged
merged 6 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions src/snowflake/cli/api/commands/snow_typer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

import logging
from functools import wraps
from typing import Optional

import typer
from snowflake.cli.api.commands.decorators import (
global_options,
global_options_with_connection,
)
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS
from snowflake.cli.api.exceptions import CommandReturnTypeError
from snowflake.cli.api.output.types import CommandResult
from snowflake.cli.app.printing import print_result
from snowflake.cli.app.telemetry import flush_telemetry, log_command_usage

log = logging.getLogger(__name__)


class SnowTyper(typer.Typer):
def __init__(self, /, **kwargs):
super().__init__(
**kwargs,
context_settings=DEFAULT_CONTEXT_SETTINGS,
pretty_exceptions_show_locals=False,
)

@wraps(typer.Typer.command)
def command(
self,
name: Optional[str] = None,
requires_global_options: bool = True,
requires_connection: bool = False,
**kwargs,
):
"""
Custom implementation of Typer.command that adds ability to execute additional
logic before and after execution as well as process the result and act on possible
errors.
"""

def custom_command(command_callable):
"""Custom command wrapper similar to Typer.command."""
if requires_connection:
command_callable = global_options_with_connection(command_callable)
elif requires_global_options:
command_callable = global_options(command_callable)

@wraps(command_callable)
def command_callable_decorator(*args, **kw):
"""Wrapper around command callable. This is what happens at "runtime"."""
self.pre_execute()
try:
result = command_callable(*args, **kw)
return self.process_result(result)
except Exception as err:
self.exception_handler(err)
raise
finally:
self.post_execute()

return super(SnowTyper, self).command(name=name, **kwargs)(
command_callable_decorator
)

return custom_command

@staticmethod
def pre_execute():
"""
Callback executed before running any command callable (after context execution).
Pay attention to make this method safe to use if performed operations are not necessary
for executing the command in proper way.
"""
log.debug("Executing command pre execution callback")
log_command_usage()

@staticmethod
def process_result(result):
"""Command result processor"""
if not isinstance(result, CommandResult):
raise CommandReturnTypeError(type(result))
print_result(result)

@staticmethod
def exception_handler(exception: Exception):
"""
Callback executed on command execution error.
"""
log.debug("Executing command exception callback")

@staticmethod
def post_execute():
"""
Callback executed after running any command callable. Pay attention to make this method safe to
use if performed operations are not necessary for executing the command in proper way.
"""
log.debug("Executing command post execution callback")
flush_telemetry()
9 changes: 9 additions & 0 deletions src/snowflake/cli/api/utils/error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from contextlib import contextmanager


@contextmanager
def ignore_exceptions():
try:
yield
except:
pass
6 changes: 4 additions & 2 deletions src/snowflake/cli/app/main_typer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import sys

import typer
from rich import print as rich_print
from snowflake.cli.api.cli_global_context import cli_context
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS, DebugOption
from typer import Typer


def _handle_exception(exception: Exception):
Expand All @@ -17,7 +19,7 @@ def _handle_exception(exception: Exception):
raise SystemExit(1)


class SnowCliMainTyper(Typer):
class SnowCliMainTyper(typer.Typer):
"""
Top-level SnowCLI Typer.
It contains global exception handling.
Expand Down
12 changes: 2 additions & 10 deletions src/snowflake/cli/app/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import os
from typing import Dict, Optional

import click
import snowflake.connector
from click.exceptions import ClickException
from snowflake.cli.api.config import get_connection, get_default_connection
from snowflake.cli.api.exceptions import (
InvalidConnectionConfiguration,
SnowflakeConnectionError,
)
from snowflake.cli.app.telemetry import command_info
from snowflake.connector import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, ForbiddenError

Expand Down Expand Up @@ -47,7 +47,7 @@ def connect_to_snowflake(temporary_connection: bool = False, connection_name: Op
# for cases when external browser and json format are used.
with contextlib.redirect_stdout(None):
return snowflake.connector.connect(
application=_find_command_path(),
application=command_info(),
**connection_parameters,
)
except ForbiddenError as err:
Expand All @@ -69,14 +69,6 @@ def _update_connection_details_with_private_key(connection_parameters: Dict):
return connection_parameters


def _find_command_path():
ctx = click.get_current_context(silent=True)
if ctx:
# Example: SNOWCLI.WAREHOUSE.STATUS
return ".".join(["SNOWCLI", *ctx.command_path.split(" ")[1:]]).upper()
return "SNOWCLI"


def _load_pem_to_der(private_key_path: str) -> bytes:
"""
Given a private key file path (in PEM format), decode key data into DER
Expand Down
117 changes: 117 additions & 0 deletions src/snowflake/cli/app/telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

import platform
import sys
from enum import Enum, unique
from typing import Any, Dict, Union

import click
from snowflake.cli.__about__ import VERSION
from snowflake.cli.api.cli_global_context import cli_context
from snowflake.cli.api.output.formats import OutputFormat
from snowflake.cli.api.utils.error_handling import ignore_exceptions
from snowflake.connector.telemetry import (
TelemetryData,
TelemetryField,
)
from snowflake.connector.time_util import get_time_millis


@unique
class CLITelemetryField(Enum):
# Basic information
SOURCE = "source"
VERSION_CLI = "version_cli"
VERSION_PYTHON = "version_python"
VERSION_OS = "version_os"
# Command execution context
COMMAND = "command"
COMMAND_GROUP = "command_group"
COMMAND_FLAGS = "command_flags"
COMMAND_OUTPUT_TYPE = "command_output_type"
# Information
EVENT = "event"
ERROR_MSG = "error_msg"
ERROR_TYPE = "error_type"


class TelemetryEvent(Enum):
CMD_EXECUTION = "executing_command"
sfc-gh-pjob marked this conversation as resolved.
Show resolved Hide resolved


TelemetryDict = Dict[Union[CLITelemetryField, TelemetryField], Any]


def _find_command_info() -> TelemetryDict:
ctx = click.get_current_context()
command_path = ctx.command_path.split(" ")[1:]
return {
CLITelemetryField.COMMAND: command_path,
CLITelemetryField.COMMAND_GROUP: command_path[0],
CLITelemetryField.COMMAND_FLAGS: {
k: ctx.get_parameter_source(k).name # type: ignore[attr-defined]
for k, v in ctx.params.items()
if v # noqa
},
CLITelemetryField.COMMAND_OUTPUT_TYPE: ctx.params.get(
"format", OutputFormat.TABLE
).value,
}


def command_info() -> str:
info = _find_command_info()
return ("SNOWCLI." + ".".join(info[CLITelemetryField.COMMAND])).upper()


def python_version() -> str:
py_ver = sys.version_info
return f"{py_ver.major}.{py_ver.minor}.{py_ver.micro}"


class CLITelemetryClient:
def __init__(self, ctx):
self._ctx = ctx

@staticmethod
def generate_telemetry_data_dict(
telemetry_payload: TelemetryDict,
) -> Dict[str, Any]:
data = {
CLITelemetryField.SOURCE: "snowcli",
CLITelemetryField.VERSION_CLI: VERSION,
CLITelemetryField.VERSION_OS: platform.platform(),
CLITelemetryField.VERSION_PYTHON: python_version(),
**_find_command_info(),
**telemetry_payload,
}
# To map Enum to string, so we don't have to use .value every time
return {getattr(k, "value", k): v for k, v in data.items()} # type: ignore[arg-type]

@property
def _telemetry(self):
return self._ctx.connection._telemetry # noqa

def send(self, payload: TelemetryDict):
if self._telemetry:
message = self.generate_telemetry_data_dict(payload)
telemetry_data = TelemetryData.from_telemetry_data_dict(
from_dict=message, timestamp=get_time_millis()
)
self._telemetry.try_add_log_to_batch(telemetry_data)

def flush(self):
self._telemetry.send_batch()


_telemetry = CLITelemetryClient(ctx=cli_context)


@ignore_exceptions()
def log_command_usage():
_telemetry.send({TelemetryField.KEY_TYPE: TelemetryEvent.CMD_EXECUTION.value})


@ignore_exceptions()
def flush_telemetry():
_telemetry.flush()
15 changes: 4 additions & 11 deletions src/snowflake/cli/plugins/connection/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import typer
from click import ClickException
from click.types import StringParamType
from snowflake.cli.api.commands.decorators import global_options, with_output
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS, ConnectionOption
from snowflake.cli.api.commands.flags import ConnectionOption
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.config import (
add_connection,
connection_exists,
Expand All @@ -20,8 +20,7 @@
)
from snowflake.connector.config_manager import CONFIG_MANAGER

app = typer.Typer(
context_settings=DEFAULT_CONTEXT_SETTINGS,
app = SnowTyper(
name="connection",
help="Manages connections to Snowflake.",
)
Expand All @@ -45,8 +44,6 @@ def _mask_password(connection_params: dict):


@app.command(name="list")
@with_output
@global_options
def list_connections(**options) -> CommandResult:
"""
Lists configured connections.
Expand All @@ -71,8 +68,6 @@ def callback(value: str):


@app.command()
@global_options
@with_output
def add(
connection_name: str = typer.Option(
None,
Expand Down Expand Up @@ -210,9 +205,7 @@ def add(
)


@app.command()
@global_options
@with_output
@app.command(requires_connection=False)
sfc-gh-mraba marked this conversation as resolved.
Show resolved Hide resolved
def test(connection: str = ConnectionOption, **options) -> CommandResult:
"""
Tests the connection to Snowflake.
Expand Down
11 changes: 3 additions & 8 deletions src/snowflake/cli/plugins/sql/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@
from typing import Optional

import typer
from snowflake.cli.api.commands.decorators import (
global_options_with_connection,
with_output,
)
from snowflake.cli.api.commands.snow_typer import SnowTyper
from snowflake.cli.api.output.types import CommandResult, MultipleResults, QueryResult
from snowflake.cli.plugins.sql.manager import SqlManager

# simple Typer with defaults because it won't become a command group as it contains only one command
app = typer.Typer()
app = SnowTyper()


@app.command(name="sql")
@with_output
@global_options_with_connection
@app.command(name="sql", requires_connection=True)
def execute_sql(
query: Optional[str] = typer.Option(
None,
Expand Down
Loading