Skip to content

Commit

Permalink
🎨 Ability to fetch size from query
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilbadyal committed Aug 21, 2024
1 parent a7db509 commit 1821b1c
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 23 deletions.
5 changes: 4 additions & 1 deletion esxport/click_opt/cli_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""CLII options."""
from __future__ import annotations

import ast
import json
from typing import Any

Expand Down Expand Up @@ -61,7 +62,9 @@ def __init__(self: Self, myclass_kwargs: dict[str, Any]) -> None:
self.fields: list[str] = list(self.fields)
self.index_prefixes: list[str] = list(self.index_prefixes)
self.meta_fields: list[str] = list(self.meta_fields)
self.max_results = int(self.max_results)
if isinstance(self.query, str):
self.query = ast.literal_eval(self.query)
self.max_results = self.query["size"] if self.query.get("size") else int(self.max_results)
self.scroll_size = int(self.scroll_size)
self.export_format: str = "csv"

Expand Down
52 changes: 32 additions & 20 deletions esxport/esxport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@
FieldNotFoundError,
HealthCheckError,
IndexNotFoundError,
InvalidEsQueryError,
MetaFieldNotFoundError,
ScrollExpiredError,
)
from .strings import index_not_found, meta_field_not_found, output_fields, sorting_by, using_indexes, using_query
from .strings import (
index_not_found,
meta_field_not_found,
output_fields,
query_key_missing,
sorting_by,
using_indexes,
using_query,
)
from .writer import Writer

if TYPE_CHECKING:
Expand Down Expand Up @@ -99,25 +108,28 @@ def _validate_fields(self: Self) -> None:

def _prepare_search_query(self: Self) -> None:
"""Prepares search query from input."""
self.search_args = {
"index": ",".join(self.opts.index_prefixes),
"scroll": self.scroll_time,
"size": self.opts.scroll_size,
"terminate_after": self.opts.max_results,
"query": Json().convert(self.opts.query, None, None)["query"],
}
if self.opts.sort:
self.search_args["sort"] = self.opts.sort

if "_all" not in self.opts.fields:
self.search_args["_source_includes"] = ",".join(self.opts.fields)

if self.opts.debug:
logger.debug(using_indexes.format(indexes={", ".join(self.opts.index_prefixes)}))
query = json.dumps(self.opts.query, default=str)
logger.debug(using_query.format(query={query}))
logger.debug(output_fields.format(fields={", ".join(self.opts.fields)}))
logger.debug(sorting_by.format(sort=self.opts.sort))
try:
self.search_args = {
"index": ",".join(self.opts.index_prefixes),
"scroll": self.scroll_time,
"size": self.opts.scroll_size,
"terminate_after": self.opts.max_results,
"query": Json().convert(self.opts.query, None, None)["query"],
}
if self.opts.sort:
self.search_args["sort"] = self.opts.sort

if "_all" not in self.opts.fields:
self.search_args["_source_includes"] = ",".join(self.opts.fields)

if self.opts.debug:
logger.debug(using_indexes.format(indexes={", ".join(self.opts.index_prefixes)}))
query = json.dumps(self.opts.query, default=str)
logger.debug(using_query.format(query={query}))
logger.debug(output_fields.format(fields={", ".join(self.opts.fields)}))
logger.debug(sorting_by.format(sort=self.opts.sort))
except KeyError as e:
raise InvalidEsQueryError(query_key_missing) from e

@retry(
wait=wait_exponential(2),
Expand Down
4 changes: 4 additions & 0 deletions esxport/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ class ScrollExpiredError(EsXportError):

class HealthCheckError(EsXportError):
"""Health check error."""


class InvalidEsQueryError(EsXportError):
"""Invalid query param."""
1 change: 1 addition & 0 deletions esxport/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
invalid_sort_format = 'Invalid input format: "{value}". Use the format "field:sort_order".'
invalid_query_format = "{value} is not a valid json string, caused {exc}"
cli_version = "EsXport Cli {__version__}"
query_key_missing = "Query key not found."
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def index_name() -> str:


@pytest.fixture()
def es_index(index_name: str, elasticsearch_proc: Elasticsearch) -> Any:
def es_index(index_name: str, elasticsearch_proc: Elasticsearch) -> str:
"""Create index."""
elasticsearch_proc.indices.create(index=index_name)
return index_name
Expand Down
16 changes: 16 additions & 0 deletions test/elastic/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

import pytest
from elastic_transport import ObjectApiResponse

from esxport.exceptions import ScrollExpiredError

Expand Down Expand Up @@ -64,3 +65,18 @@ def test_scroll_expired(self: Self, elastic_client: ElasticsearchClient) -> None
"""Test client return true when index exists."""
with pytest.raises(ScrollExpiredError):
elastic_client.scroll(scroll="5m", scroll_id="brqwdwefwef")

@pytest.mark.xdist_group(name="elastic")
def test_ping(self: Self, elastic_client: ElasticsearchClient) -> None:
"""Test that ping returns valid cluster information."""
response = elastic_client.ping()

# Assert that the response is an instance of ObjectApiResponse
assert isinstance(response, ObjectApiResponse), "Ping response should be an ObjectApiResponse."

# Convert to dictionary and check for cluster information
response_dict = response.raw
assert isinstance(response_dict, dict), "Ping response should be convertible to a dictionary."
assert "cluster_name" in response_dict, "Cluster name should be present in the ping response."
assert "version" in response_dict, "Elasticsearch version should be present in the ping response."
assert "tagline" in response_dict, "Tagline should be present in the ping response."
13 changes: 13 additions & 0 deletions test/esxport/_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing_extensions import Self

from esxport.esxport import EsXport
from esxport.exceptions import HealthCheckError


@patch("esxport.esxport.EsXport._validate_fields")
Expand Down Expand Up @@ -65,3 +66,15 @@ def test_headers_extraction(
json.dump(test_json, tmp_file)
assert esxport_obj._extract_headers() == list(test_json.keys())
TestExport.rm_export_file(f"{inspect.stack()[0].function}.csv")

def test_ping_cluster_failure(self: Self, _: Any, esxport_obj: EsXport) -> None:
"""Test that _ping_cluster raises HealthCheckError when ping fails."""
with patch.object(esxport_obj.es_client, "ping", side_effect=ConnectionError("mocked error")), pytest.raises(
HealthCheckError,
):
esxport_obj._ping_cluster()

def test_ping_cluster_success(self: Self, _: Any, esxport_obj: EsXport) -> None:
"""Test that _ping_cluster succeeds when ping is successful."""
with patch.object(esxport_obj.es_client, "ping", return_value={}):
esxport_obj._ping_cluster()
10 changes: 9 additions & 1 deletion test/esxport/_prepare_search_query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from esxport.exceptions import IndexNotFoundError
from esxport.exceptions import IndexNotFoundError, InvalidEsQueryError
from esxport.strings import index_not_found, output_fields, sorting_by, using_indexes

if TYPE_CHECKING:
Expand Down Expand Up @@ -147,3 +147,11 @@ def test_custom_output_fields(self: Self, _: Any, esxport_obj: EsXport) -> None:
esxport_obj.opts.fields = random_strings
esxport_obj._prepare_search_query()
assert esxport_obj.search_args["_source_includes"] == ",".join(random_strings)

def test_error_raised_when_query_key_missing(self: Self, _: Any, esxport_obj: EsXport) -> None:
"""Test if selection only some fields for the output works."""
expected_query: dict[str, Any] = {"size": 10}
esxport_obj.opts.query = expected_query

with pytest.raises(InvalidEsQueryError):
esxport_obj._prepare_search_query()

0 comments on commit 1821b1c

Please sign in to comment.