-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SnowTyper together with telemetry (#731)
- Loading branch information
1 parent
d39ac23
commit 23b2c2f
Showing
17 changed files
with
656 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from functools import wraps | ||
from typing import Optional | ||
|
||
import typer | ||
from snowflake.cli.api.commands.decorators import ( | ||
global_options, | ||
global_options_with_connection, | ||
) | ||
from snowflake.cli.api.commands.flags import DEFAULT_CONTEXT_SETTINGS | ||
from snowflake.cli.api.exceptions import CommandReturnTypeError | ||
from snowflake.cli.api.output.types import CommandResult | ||
from snowflake.cli.app.printing import print_result | ||
from snowflake.cli.app.telemetry import flush_telemetry, log_command_usage | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class SnowTyper(typer.Typer): | ||
def __init__(self, /, **kwargs): | ||
super().__init__( | ||
**kwargs, | ||
context_settings=DEFAULT_CONTEXT_SETTINGS, | ||
pretty_exceptions_show_locals=False, | ||
) | ||
|
||
@wraps(typer.Typer.command) | ||
def command( | ||
self, | ||
name: Optional[str] = None, | ||
requires_global_options: bool = True, | ||
requires_connection: bool = False, | ||
**kwargs, | ||
): | ||
""" | ||
Custom implementation of Typer.command that adds ability to execute additional | ||
logic before and after execution as well as process the result and act on possible | ||
errors. | ||
""" | ||
|
||
def custom_command(command_callable): | ||
"""Custom command wrapper similar to Typer.command.""" | ||
if requires_connection: | ||
command_callable = global_options_with_connection(command_callable) | ||
elif requires_global_options: | ||
command_callable = global_options(command_callable) | ||
|
||
@wraps(command_callable) | ||
def command_callable_decorator(*args, **kw): | ||
"""Wrapper around command callable. This is what happens at "runtime".""" | ||
self.pre_execute() | ||
try: | ||
result = command_callable(*args, **kw) | ||
return self.process_result(result) | ||
except Exception as err: | ||
self.exception_handler(err) | ||
raise | ||
finally: | ||
self.post_execute() | ||
|
||
return super(SnowTyper, self).command(name=name, **kwargs)( | ||
command_callable_decorator | ||
) | ||
|
||
return custom_command | ||
|
||
@staticmethod | ||
def pre_execute(): | ||
""" | ||
Callback executed before running any command callable (after context execution). | ||
Pay attention to make this method safe to use if performed operations are not necessary | ||
for executing the command in proper way. | ||
""" | ||
log.debug("Executing command pre execution callback") | ||
log_command_usage() | ||
|
||
@staticmethod | ||
def process_result(result): | ||
"""Command result processor""" | ||
if not isinstance(result, CommandResult): | ||
raise CommandReturnTypeError(type(result)) | ||
print_result(result) | ||
|
||
@staticmethod | ||
def exception_handler(exception: Exception): | ||
""" | ||
Callback executed on command execution error. | ||
""" | ||
log.debug("Executing command exception callback") | ||
|
||
@staticmethod | ||
def post_execute(): | ||
""" | ||
Callback executed after running any command callable. Pay attention to make this method safe to | ||
use if performed operations are not necessary for executing the command in proper way. | ||
""" | ||
log.debug("Executing command post execution callback") | ||
flush_telemetry() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from contextlib import contextmanager | ||
|
||
|
||
@contextmanager | ||
def ignore_exceptions(): | ||
try: | ||
yield | ||
except: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from __future__ import annotations | ||
|
||
import platform | ||
import sys | ||
from enum import Enum, unique | ||
from typing import Any, Dict, Union | ||
|
||
import click | ||
from snowflake.cli.__about__ import VERSION | ||
from snowflake.cli.api.cli_global_context import cli_context | ||
from snowflake.cli.api.output.formats import OutputFormat | ||
from snowflake.cli.api.utils.error_handling import ignore_exceptions | ||
from snowflake.connector.telemetry import ( | ||
TelemetryData, | ||
TelemetryField, | ||
) | ||
from snowflake.connector.time_util import get_time_millis | ||
|
||
|
||
@unique | ||
class CLITelemetryField(Enum): | ||
# Basic information | ||
SOURCE = "source" | ||
VERSION_CLI = "version_cli" | ||
VERSION_PYTHON = "version_python" | ||
VERSION_OS = "version_os" | ||
# Command execution context | ||
COMMAND = "command" | ||
COMMAND_GROUP = "command_group" | ||
COMMAND_FLAGS = "command_flags" | ||
COMMAND_OUTPUT_TYPE = "command_output_type" | ||
# Information | ||
EVENT = "event" | ||
ERROR_MSG = "error_msg" | ||
ERROR_TYPE = "error_type" | ||
|
||
|
||
class TelemetryEvent(Enum): | ||
CMD_EXECUTION = "executing_command" | ||
|
||
|
||
TelemetryDict = Dict[Union[CLITelemetryField, TelemetryField], Any] | ||
|
||
|
||
def _find_command_info() -> TelemetryDict: | ||
ctx = click.get_current_context() | ||
command_path = ctx.command_path.split(" ")[1:] | ||
return { | ||
CLITelemetryField.COMMAND: command_path, | ||
CLITelemetryField.COMMAND_GROUP: command_path[0], | ||
CLITelemetryField.COMMAND_FLAGS: { | ||
k: ctx.get_parameter_source(k).name # type: ignore[attr-defined] | ||
for k, v in ctx.params.items() | ||
if v # noqa | ||
}, | ||
CLITelemetryField.COMMAND_OUTPUT_TYPE: ctx.params.get( | ||
"format", OutputFormat.TABLE | ||
).value, | ||
} | ||
|
||
|
||
def command_info() -> str: | ||
info = _find_command_info() | ||
return ("SNOWCLI." + ".".join(info[CLITelemetryField.COMMAND])).upper() | ||
|
||
|
||
def python_version() -> str: | ||
py_ver = sys.version_info | ||
return f"{py_ver.major}.{py_ver.minor}.{py_ver.micro}" | ||
|
||
|
||
class CLITelemetryClient: | ||
def __init__(self, ctx): | ||
self._ctx = ctx | ||
|
||
@staticmethod | ||
def generate_telemetry_data_dict( | ||
telemetry_payload: TelemetryDict, | ||
) -> Dict[str, Any]: | ||
data = { | ||
CLITelemetryField.SOURCE: "snowcli", | ||
CLITelemetryField.VERSION_CLI: VERSION, | ||
CLITelemetryField.VERSION_OS: platform.platform(), | ||
CLITelemetryField.VERSION_PYTHON: python_version(), | ||
**_find_command_info(), | ||
**telemetry_payload, | ||
} | ||
# To map Enum to string, so we don't have to use .value every time | ||
return {getattr(k, "value", k): v for k, v in data.items()} # type: ignore[arg-type] | ||
|
||
@property | ||
def _telemetry(self): | ||
return self._ctx.connection._telemetry # noqa | ||
|
||
def send(self, payload: TelemetryDict): | ||
if self._telemetry: | ||
message = self.generate_telemetry_data_dict(payload) | ||
telemetry_data = TelemetryData.from_telemetry_data_dict( | ||
from_dict=message, timestamp=get_time_millis() | ||
) | ||
self._telemetry.try_add_log_to_batch(telemetry_data) | ||
|
||
def flush(self): | ||
self._telemetry.send_batch() | ||
|
||
|
||
_telemetry = CLITelemetryClient(ctx=cli_context) | ||
|
||
|
||
@ignore_exceptions() | ||
def log_command_usage(): | ||
_telemetry.send({TelemetryField.KEY_TYPE: TelemetryEvent.CMD_EXECUTION.value}) | ||
|
||
|
||
@ignore_exceptions() | ||
def flush_telemetry(): | ||
_telemetry.flush() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.