Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mlop-2454): add protection to host setting on cassandra_client #384

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ style-check:
@echo "Code Style"
@echo "=========="
@echo ""
@python -m black --check -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" . && echo "\n\nSuccess" || (echo "\n\nFailure\n\nYou need to run \"make apply-style\" to apply style formatting to your code"; exit 1)
@python -m black --check -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/|venv/" . && echo "\n\nSuccess" || (echo "\n\nFailure\n\nYou need to run \"make apply-style\" to apply style formatting to your code"; exit 1)

.PHONY: quality-check
## run code quality checks with flake8
Expand All @@ -85,7 +85,7 @@ quality-check:
@echo "Flake 8"
@echo "======="
@echo ""
@python -m flake8 && echo "Success"
@python -m flake8 --exclude="venv" && echo "Success"
@echo ""

.PHONY: type-check
Expand All @@ -95,7 +95,7 @@ type-check:
@echo "mypy"
@echo "===="
@echo ""
@python -m mypy butterfree
@python -m mypy --exclude="venv" butterfree

.PHONY: checks
## run all code checks
Expand All @@ -104,7 +104,7 @@ checks: style-check quality-check type-check
.PHONY: apply-style
## fix stylistic errors with black
apply-style:
@python -m black -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/" .
@python -m black -t py39 --exclude="build/|buck-out/|dist/|_build/|pip/|\.pip/|\.git/|\.hg/|\.mypy_cache/|\.tox/|\.venv/|venv/" .
@python -m isort --atomic butterfree/ tests/

.PHONY: clean
Expand Down
46 changes: 44 additions & 2 deletions butterfree/clients/cassandra_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""CassandraClient entity."""

from ssl import CERT_REQUIRED, PROTOCOL_TLSv1
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import (
Expand All @@ -16,6 +16,12 @@
from typing_extensions import TypedDict

from butterfree.clients import AbstractClient
from butterfree.configs.logger import __logger

logger = __logger("cassandra_client")

EMPTY_STRING_HOST_ERROR = "The value of Cassandra host is empty. Please fill correctly with your endpoints" # noqa: E501
GENERIC_INVALID_HOST_ERROR = "The Cassandra host must be a valid string, a string that represents a list or list of strings" # noqa: E501


class CassandraColumn(TypedDict):
Expand Down Expand Up @@ -53,12 +59,48 @@ def __init__(
user: Optional[str] = None,
password: Optional[str] = None,
) -> None:
self.host = host
self.host = self._validate_and_format_cassandra_host(host)
logger.info(f"The host setted is {self.host}")
self.keyspace = keyspace
self.user = user
self.password = password
self._session: Optional[Session] = None

def _validate_and_format_cassandra_host(self, host: Union[List, str]):
"""
Validate and format the provided Cassandra host input.

This method checks if the input `host` is either a string, a list of strings, or
a list containing a single string with comma-separated values. It splits the string
by commas and trims whitespace, returning a list of hosts. If the input is already
a list of strings, it returns that list. If the input is empty or invalid, a
ValueError is raised.

Args:
host (str | list): The Cassandra host input, which can be a comma-separated
string or a list of string endpoints.

Returns:
list: A list of formatted Cassandra host strings.

Raises:
ValueError: If the input is an empty string or if it is not a string
(or a representation of a list) or a list of strings.
""" # noqa: E501
if isinstance(host, str):
if host:
return [item.strip() for item in host.split(",")]
else:
raise ValueError(EMPTY_STRING_HOST_ERROR)

if isinstance(host, list):
if len(host) == 1 and isinstance(host[0], str):
return [item.strip() for item in host[0].split(",")]
elif all(isinstance(item, str) for item in host):
return host

raise ValueError(GENERIC_INVALID_HOST_ERROR)

@property
def conn(self, *, ssl_path: str = None) -> Session: # type: ignore
"""Establishes a Cassandra connection."""
Expand Down
52 changes: 51 additions & 1 deletion tests/unit/butterfree/clients/test_cassandra_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Any, Dict, List
from unittest.mock import MagicMock

import pytest

from butterfree.clients import CassandraClient
from butterfree.clients.cassandra_client import CassandraColumn
from butterfree.clients.cassandra_client import (
EMPTY_STRING_HOST_ERROR,
GENERIC_INVALID_HOST_ERROR,
CassandraColumn,
)


def sanitize_string(query: str) -> str:
Expand Down Expand Up @@ -86,3 +92,47 @@ def test_cassandra_create_table(
query = cassandra_client.sql.call_args[0][0]

assert sanitize_string(query) == sanitize_string(expected_query)

def test_initialize_with_string_host(self):
client = CassandraClient(host="127.0.0.0, 127.0.0.1", keyspace="dummy_keyspace")
assert client.host == ["127.0.0.0", "127.0.0.1"]

def test_initialize_with_list_host(self):
client = CassandraClient(
host=["127.0.0.0", "127.0.0.1"], keyspace="test_keyspace"
)
assert client.host == ["127.0.0.0", "127.0.0.1"]

def test_initialize_with_empty_string_host(self):
with pytest.raises(
ValueError,
match=EMPTY_STRING_HOST_ERROR,
):
CassandraClient(host="", keyspace="test_keyspace")

def test_initialize_with_none_host(self):
with pytest.raises(
ValueError,
match=GENERIC_INVALID_HOST_ERROR,
):
CassandraClient(host=None, keyspace="test_keyspace")

def test_initialize_with_invalid_host_type(self):
with pytest.raises(
ValueError,
match=GENERIC_INVALID_HOST_ERROR,
):
CassandraClient(host=123, keyspace="test_keyspace")

def test_initialize_with_invalid_list_host(self):
with pytest.raises(
ValueError,
match=GENERIC_INVALID_HOST_ERROR,
):
CassandraClient(host=["127.0.0.0", 123], keyspace="test_keyspace")

def test_initialize_with_list_of_string_hosts(self):
client = CassandraClient(
host=["127.0.0.0, 127.0.0.1"], keyspace="test_keyspace"
)
assert client.host == ["127.0.0.0", "127.0.0.1"]
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def target_df_spark(spark_context, spark_session):
"timestamp": "2016-04-11 11:31:11",
"feature1": 200,
"feature2": 200,
"feature__cos": 0.48718767500700594,
"feature__cos": 0.4871876750070059,
},
{
"id": 1,
Expand Down
Loading