Skip to content

Commit

Permalink
Added support to other naming strategy, refactored different unrelate…
Browse files Browse the repository at this point in the history
…d stuff and added a couple of tests
  • Loading branch information
eliax1996 committed Nov 9, 2023
1 parent 101d69e commit 535d652
Show file tree
Hide file tree
Showing 12 changed files with 479 additions and 150 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ Keys to take special care are the ones needed to configure Kafka and advertised_
- Runtime directory for the ``protoc`` protobuf schema parser and code generator
* - ``name_strategy``
- ``topic_name``
- Name strategy to use when storing schemas from the kafka rest proxy service
- Name strategy to use when storing schemas from the kafka rest proxy service. You can opt between ``name_strategy`` , ``record_name`` and ``topic_record_name``
* - ``name_strategy_validation``
- ``true``
- If enabled, validate that given schema is registered under used name strategy when producing messages from Kafka Rest
Expand Down
21 changes: 5 additions & 16 deletions karapace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"""
from __future__ import annotations

from enum import Enum, unique
from karapace.constants import DEFAULT_AIOHTTP_CLIENT_MAX_SIZE, DEFAULT_PRODUCER_MAX_REQUEST, DEFAULT_SCHEMA_TOPIC
from karapace.typing import ElectionStrategy, NameStrategy
from karapace.utils import json_decode, json_encode, JSONDecodeError
from pathlib import Path
from typing import IO, Mapping
Expand Down Expand Up @@ -158,19 +158,6 @@ class InvalidConfiguration(Exception):
pass


@unique
class ElectionStrategy(Enum):
highest = "highest"
lowest = "lowest"


@unique
class NameStrategy(Enum):
topic_name = "topic_name"
record_name = "record_name"
topic_record_name = "topic_record_name"


def parse_env_value(value: str) -> str | int | bool:
# we only have ints, strings and bools in the config
try:
Expand Down Expand Up @@ -273,8 +260,10 @@ def validate_config(config: Config) -> None:
try:
NameStrategy(name_strategy)
except ValueError:
valid_strategies = [strategy.value for strategy in NameStrategy]
raise InvalidConfiguration(f"Invalid name strategy: {name_strategy}, valid values are {valid_strategies}") from None
valid_strategies = list(NameStrategy)
raise InvalidConfiguration(
f"Invalid default name strategy: {name_strategy}, valid values are {valid_strategies}"
) from None

if config["rest_authorization"] and config["sasl_bootstrap_uri"] is None:
raise InvalidConfiguration(
Expand Down
58 changes: 37 additions & 21 deletions karapace/kafka_rest_apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,14 @@
from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE
from karapace.schema_models import TypedSchema, ValidatedTypedSchema
from karapace.schema_type import SchemaType
from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError
from karapace.typing import SchemaId, Subject
from karapace.serialization import (
get_subject_name,
InvalidMessageSchema,
InvalidPayload,
SchemaRegistrySerializer,
SchemaRetrievalError,
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from typing import Callable, Dict, List, Optional, Tuple, Union

Expand All @@ -39,7 +45,7 @@
import logging
import time

RECORD_KEYS = ["key", "value", "partition"]
SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value]
PUBLISH_KEYS = {"records", "value_schema", "value_schema_id", "key_schema", "key_schema_id"}
RECORD_CODES = [42201, 42202]
KNOWN_FORMATS = {"json", "avro", "protobuf", "binary"}
Expand Down Expand Up @@ -439,6 +445,7 @@ def __init__(

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AIOKafkaProducer] = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])

def __str__(self) -> str:
return f"UserRestProxy(username={self.config['sasl_plain_username']})"
Expand Down Expand Up @@ -759,7 +766,7 @@ async def get_schema_id(
self,
data: dict,
topic: str,
prefix: str,
subject_type: SubjectType,
schema_type: SchemaType,
) -> SchemaId:
"""
Expand All @@ -770,21 +777,27 @@ async def get_schema_id(
"""
log.debug("[resolve schema id] Retrieving schema id for %r", data)
schema_id: Union[SchemaId, None] = (
SchemaId(int(data[f"{prefix}_schema_id"])) if f"{prefix}_schema_id" in data else None
SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None
)
schema_str = data.get(f"{prefix}_schema")
schema_str = data.get(f"{subject_type}_schema")

if schema_id is None and schema_str is None:
raise InvalidSchema()

if schema_id is None:
parsed_schema = ValidatedTypedSchema.parse(schema_type, schema_str)
subject_name = self.serializer.get_subject_name(topic, parsed_schema, prefix, schema_type)

subject_name = get_subject_name(
topic,
parsed_schema,
subject_type,
self.naming_strategy,
)
schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
else:

def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
subject = self.serializer.get_subject_name(topic, schema, prefix, schema_type)
subject = get_subject_name(topic, schema, subject_type, self.naming_strategy)
return subject not in subjects

parsed_schema, valid_subjects = await self._query_schema_and_subjects(
Expand Down Expand Up @@ -833,7 +846,9 @@ async def _query_schema_id_from_cache_or_registry(
)
return schema_id

async def validate_schema_info(self, data: dict, prefix: str, content_type: str, topic: str, schema_type: str):
async def validate_schema_info(
self, data: dict, subject_type: SubjectType, content_type: str, topic: str, schema_type: str
):
try:
schema_type = SCHEMA_MAPPINGS[schema_type]
except KeyError:
Expand All @@ -848,7 +863,7 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str,

# will do in place updates of id keys, since calling these twice would be expensive
try:
data[f"{prefix}_schema_id"] = await self.get_schema_id(data, topic, prefix, schema_type)
data[f"{subject_type}_schema_id"] = await self.get_schema_id(data, topic, subject_type, schema_type)
except InvalidPayload:
log.exception("Unable to retrieve schema id")
KafkaRest.r(
Expand All @@ -863,16 +878,17 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str,
KafkaRest.r(
body={
"error_code": RESTErrorCodes.SCHEMA_RETRIEVAL_ERROR.value,
"message": f"Error when registering schema. format = {schema_type.value}, subject = {topic}-{prefix}",
"message": f"Error when registering schema."
f"format = {schema_type.value}, subject = {topic}-{subject_type}",
},
content_type=content_type,
status=HTTPStatus.REQUEST_TIMEOUT,
)
except InvalidSchema:
if f"{prefix}_schema" in data:
err = f'schema = {data[f"{prefix}_schema"]}'
if f"{subject_type}_schema" in data:
err = f'schema = {data[f"{subject_type}_schema"]}'
else:
err = f'schema_id = {data[f"{prefix}_schema_id"]}'
err = f'schema_id = {data[f"{subject_type}_schema_id"]}'
KafkaRest.r(
body={
"error_code": RESTErrorCodes.INVALID_DATA.value,
Expand Down Expand Up @@ -1002,26 +1018,26 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
status=HTTPStatus.BAD_REQUEST,
)
convert_to_int(r, "partition", content_type)
if set(r.keys()).difference(RECORD_KEYS):
if set(r.keys()).difference({subject_type.value for subject_type in SubjectType}):
KafkaRest.unprocessable_entity(
message="Invalid request format",
content_type=content_type,
sub_code=RESTErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value,
)
# disallow missing id and schema for any key/value list that has at least one populated element
if formats["embedded_format"] in {"avro", "jsonschema", "protobuf"}:
for prefix, code in zip(RECORD_KEYS, RECORD_CODES):
if self.all_empty(data, prefix):
for subject_type, code in zip(SUBJECT_VALID_POSTFIX, RECORD_CODES):
if self.all_empty(data, subject_type):
continue
if not self.is_valid_schema_request(data, prefix):
if not self.is_valid_schema_request(data, subject_type):
KafkaRest.unprocessable_entity(
message=f"Request includes {prefix}s and uses a format that requires schemas "
f"but does not include the {prefix}_schema or {prefix}_schema_id fields",
message=f"Request includes {subject_type}s and uses a format that requires schemas "
f"but does not include the {subject_type}_schema or {subject_type.value}_schema_id fields",
content_type=content_type,
sub_code=code,
)
try:
await self.validate_schema_info(data, prefix, content_type, topic, formats["embedded_format"])
await self.validate_schema_info(data, subject_type, content_type, topic, formats["embedded_format"])
except InvalidMessageSchema as e:
KafkaRest.unprocessable_entity(
message=str(e),
Expand Down
58 changes: 42 additions & 16 deletions karapace/protobuf/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""

from __future__ import annotations

from karapace.dataclasses import default_dataclass

# Ported from square/wire:
Expand All @@ -21,7 +24,7 @@
from karapace.protobuf.type_element import TypeElement
from karapace.protobuf.utils import append_documentation, append_indented
from karapace.schema_references import Reference
from typing import Iterable, List, Mapping, Optional, Sequence, Set, Tuple
from typing import Iterable, Mapping, Sequence

import itertools

Expand Down Expand Up @@ -126,10 +129,10 @@ class SourceFileReference:
@default_dataclass
class TypeTree:
token: str
children: List["TypeTree"]
source_reference: Optional[SourceFileReference]
children: list[TypeTree]
source_reference: SourceFileReference | None

def source_reference_tree_recursive(self) -> Iterable[Optional[SourceFileReference]]:
def source_reference_tree_recursive(self) -> Iterable[SourceFileReference | None]:
sources = [] if self.source_reference is None else [self.source_reference]
for child in self.children:
sources = itertools.chain(sources, child.source_reference_tree())
Expand Down Expand Up @@ -201,7 +204,7 @@ def __repr__(self) -> str:

def _add_new_type_recursive(
parent_tree: TypeTree,
remaining_tokens: List[str],
remaining_tokens: list[str],
file: str,
inserted_elements: int,
) -> None:
Expand Down Expand Up @@ -249,8 +252,8 @@ class ProtobufSchema:
def __init__(
self,
schema: str,
references: Optional[Sequence[Reference]] = None,
dependencies: Optional[Mapping[str, Dependency]] = None,
references: Sequence[Reference] | None = None,
dependencies: Mapping[str, Dependency] | None = None,
) -> None:
if type(schema).__name__ != "str":
raise IllegalArgumentException("Non str type of schema string")
Expand All @@ -260,7 +263,7 @@ def __init__(
self.references = references
self.dependencies = dependencies

def type_in_tree(self, tree: TypeTree, remaining_tokens: List[str]) -> Optional[TypeTree]:
def type_in_tree(self, tree: TypeTree, remaining_tokens: list[str]) -> TypeTree | None:
if remaining_tokens:
to_seek = remaining_tokens.pop()

Expand All @@ -270,10 +273,33 @@ def type_in_tree(self, tree: TypeTree, remaining_tokens: List[str]) -> Optional[
return None
return tree

def type_exist_in_tree(self, tree: TypeTree, remaining_tokens: List[str]) -> bool:
def record_name(self) -> str | None:
if len(self.proto_file_element.types) == 0:
return None

package_name = (
self.proto_file_element.package_name + "." if self.proto_file_element.package_name not in [None, ""] else ""
)

first_element = None
first_enum = None

for inspected_type in self.proto_file_element.types:
if isinstance(inspected_type, MessageElement):
first_element = inspected_type
break

if first_enum is None and isinstance(inspected_type, EnumElement):
first_enum = inspected_type

naming_element = first_element if first_element is not None else first_enum

return package_name + naming_element.name

def type_exist_in_tree(self, tree: TypeTree, remaining_tokens: list[str]) -> bool:
return self.type_in_tree(tree, remaining_tokens) is not None

def recursive_imports(self) -> Set[str]:
def recursive_imports(self) -> set[str]:
imports = set(self.proto_file_element.imports)

if self.dependencies:
Expand All @@ -282,7 +308,7 @@ def recursive_imports(self) -> Set[str]:

return imports

def are_type_usage_valid(self, root_type_tree: TypeTree, used_types: List[UsedType]) -> Tuple[bool, Optional[str]]:
def are_type_usage_valid(self, root_type_tree: TypeTree, used_types: list[UsedType]) -> tuple[bool, str | None]:
# Please note that this check only ensures the requested type exists. However, for performance reasons, it works in
# the opposite way of how specificity works in Protobuf. In Protobuf, the type is matched not only to check if it
# exists, but also based on the order of search: local definition comes before imported types. In this code, we
Expand Down Expand Up @@ -408,7 +434,7 @@ def types_tree(self) -> TypeTree:
return root_tree

@staticmethod
def used_type(parent: str, element_type: str) -> List[UsedType]:
def used_type(parent: str, element_type: str) -> list[UsedType]:
if element_type.find("map<") == 0:
end = element_type.find(">")
virgule = element_type.find(",")
Expand All @@ -426,7 +452,7 @@ def dependencies_one_of(
package_name: str,
parent_name: str,
one_of: OneOfElement,
) -> List[UsedType]:
) -> list[UsedType]:
parent = package_name + "." + parent_name
dependencies = []
for field in one_of.fields:
Expand All @@ -438,7 +464,7 @@ def dependencies_one_of(
)
return dependencies

def used_types(self) -> List[UsedType]:
def used_types(self) -> list[UsedType]:
dependencies_used_types = []
if self.dependencies:
for key in self.dependencies:
Expand Down Expand Up @@ -469,7 +495,7 @@ def nested_used_type(
package_name: str,
parent_name: str,
element_type: TypeElement,
) -> List[str]:
) -> list[str]:
used_types = []

if isinstance(element_type, MessageElement):
Expand Down Expand Up @@ -540,7 +566,7 @@ def to_schema(self) -> str:

return "".join(strings)

def compare(self, other: "ProtobufSchema", result: CompareResult) -> CompareResult:
def compare(self, other: ProtobufSchema, result: CompareResult) -> CompareResult:
return self.proto_file_element.compare(
other.proto_file_element,
result,
Expand Down
Loading

0 comments on commit 535d652

Please sign in to comment.