diff --git a/README.rst b/README.rst index 93875719a..c18d982c6 100644 --- a/README.rst +++ b/README.rst @@ -459,12 +459,12 @@ Keys to take special care are the ones needed to configure Kafka and advertised_ * - ``protobuf_runtime_directory`` - ``runtime`` - Runtime directory for the ``protoc`` protobuf schema parser and code generator - * - ``name_strategy`` - - ``topic_name`` - - Name strategy to use when storing schemas from the kafka rest proxy service + * - ``default_name_strategy`` + - ``topic_name``, ``record_name``, ``topic_record_name``, ``no_validation`` + - Default name strategy to use when storing schemas from the kafka rest proxy service, could be overriden for each topic by calling the `/topic/{topic}/name_strategy/{strategy}` endpoint * - ``name_strategy_validation`` - ``true`` - - If enabled, validate that given schema is registered under used name strategy when producing messages from Kafka Rest + - If enabled, validate that given schema is registered under the expected subjects requireds by the specified name strategy (default or overridden) when producing messages from Kafka Rest * - ``master_election_strategy`` - ``lowest`` - Decides on what basis the Karapace cluster master is chosen (only relevant in a multi node setup) diff --git a/karapace/config.py b/karapace/config.py index c87275a8f..799c26450 100644 --- a/karapace/config.py +++ b/karapace/config.py @@ -74,7 +74,7 @@ class Config(TypedDict): session_timeout_ms: int karapace_rest: bool karapace_registry: bool - name_strategy: str + default_name_strategy: str name_strategy_validation: bool master_election_strategy: str protobuf_runtime_directory: str @@ -146,7 +146,7 @@ class ConfigDefaults(Config, total=False): "session_timeout_ms": 10000, "karapace_rest": False, "karapace_registry": False, - "name_strategy": "topic_name", + "default_name_strategy": "topic_name", "name_strategy_validation": True, "master_election_strategy": "lowest", "protobuf_runtime_directory": "runtime", @@ -158,6 +158,11 @@ class InvalidConfiguration(Exception): pass +class StrEnum(str, Enum): + def __str__(self) -> str: + return str(self.value) + + @unique class ElectionStrategy(Enum): highest = "highest" @@ -165,10 +170,18 @@ class ElectionStrategy(Enum): @unique -class NameStrategy(Enum): +class NameStrategy(StrEnum): topic_name = "topic_name" record_name = "record_name" topic_record_name = "topic_record_name" + no_validation = "no_validation_strategy" + + +@unique +class SubjectType(StrEnum): + key = "key" + value = "value" + partition = "partition" def parse_env_value(value: str) -> str | int | bool: @@ -269,12 +282,14 @@ def validate_config(config: Config) -> None: f"Invalid master election strategy: {master_election_strategy}, valid values are {valid_strategies}" ) from None - name_strategy = config["name_strategy"] + deafault_name_strategy = config["default_name_strategy"] try: - NameStrategy(name_strategy) + NameStrategy(deafault_name_strategy) except ValueError: - valid_strategies = [strategy.value for strategy in NameStrategy] - raise InvalidConfiguration(f"Invalid name strategy: {name_strategy}, valid values are {valid_strategies}") from None + valid_strategies = list(NameStrategy) + raise InvalidConfiguration( + f"Invalid default name strategy: {deafault_name_strategy}, valid values are {valid_strategies}" + ) from None if config["rest_authorization"] and config["sasl_bootstrap_uri"] is None: raise InvalidConfiguration( diff --git a/karapace/in_memory_database.py b/karapace/in_memory_database.py index 222e38046..f1fa5db17 100644 --- a/karapace/in_memory_database.py +++ b/karapace/in_memory_database.py @@ -7,9 +7,10 @@ from __future__ import annotations from dataclasses import dataclass, field +from karapace.config import NameStrategy from karapace.schema_models import SchemaVersion, TypedSchema from karapace.schema_references import Reference, Referents -from karapace.typing import ResolvedVersion, SchemaId, Subject +from karapace.typing import ResolvedVersion, SchemaId, Subject, TopicName from threading import Lock, RLock from typing import Iterable, Sequence @@ -32,6 +33,7 @@ def __init__(self) -> None: self.schemas: dict[SchemaId, TypedSchema] = {} self.schema_lock_thread = RLock() self.referenced_by: dict[tuple[Subject, ResolvedVersion], Referents] = {} + self.topic_validation_strategies: dict[TopicName, NameStrategy] = {} # Content based deduplication of schemas. This is used to reduce memory # usage when the same schema is produce multiple times to the same or @@ -229,6 +231,15 @@ def find_subject_schemas(self, *, subject: Subject, include_deleted: bool) -> di if schema_version.deleted is False } + def get_topic_strategy(self, *, topic_name: TopicName) -> NameStrategy | None: + if topic_name not in self.topic_validation_strategies: + return None + + return self.topic_validation_strategies[topic_name] + + def override_topic_strategy(self, *, topic_name: TopicName, name_strategy: NameStrategy) -> None: + self.topic_validation_strategies[topic_name] = name_strategy + def delete_subject(self, *, subject: Subject, version: ResolvedVersion) -> None: with self.schema_lock_thread: for schema_version in self.subjects[subject].schemas.values(): diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index c63194e52..ed3dd7630 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -13,7 +13,7 @@ TopicAuthorizationFailedError, UnknownTopicOrPartitionError, ) -from karapace.config import Config, create_client_ssl_context +from karapace.config import Config, create_client_ssl_context, NameStrategy, SubjectType from karapace.errors import InvalidSchema from karapace.kafka_rest_apis.admin import KafkaRestAdminClient from karapace.kafka_rest_apis.authentication import ( @@ -28,8 +28,14 @@ from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE from karapace.schema_models import TypedSchema, ValidatedTypedSchema from karapace.schema_type import SchemaType -from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError -from karapace.typing import SchemaId, Subject +from karapace.serialization import ( + get_subject_name, + InvalidMessageSchema, + InvalidPayload, + SchemaRegistrySerializer, + SchemaRetrievalError, +) +from karapace.typing import SchemaId, Subject, TopicName from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient from typing import Callable, Dict, List, Optional, Tuple, Union @@ -39,7 +45,7 @@ import logging import time -RECORD_KEYS = ["key", "value", "partition"] +SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value] PUBLISH_KEYS = {"records", "value_schema", "value_schema_id", "key_schema", "key_schema_id"} RECORD_CODES = [42201, 42202] KNOWN_FORMATS = {"json", "avro", "protobuf", "binary"} @@ -759,7 +765,7 @@ async def get_schema_id( self, data: dict, topic: str, - prefix: str, + subject_type: SubjectType, schema_type: SchemaType, ) -> SchemaId: """ @@ -770,21 +776,23 @@ async def get_schema_id( """ log.debug("[resolve schema id] Retrieving schema id for %r", data) schema_id: Union[SchemaId, None] = ( - SchemaId(int(data[f"{prefix}_schema_id"])) if f"{prefix}_schema_id" in data else None + SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None ) - schema_str = data.get(f"{prefix}_schema") + schema_str = data.get(f"{subject_type}_schema") + naming_strategy = await self.serializer.get_topic_strategy_name(topic_name=TopicName(topic)) if schema_id is None and schema_str is None: raise InvalidSchema() if schema_id is None: parsed_schema = ValidatedTypedSchema.parse(schema_type, schema_str) - subject_name = self.serializer.get_subject_name(topic, parsed_schema, prefix, schema_type) + + subject_name = get_subject_name(topic, parsed_schema, subject_type, naming_strategy) schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name) else: def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool: - subject = self.serializer.get_subject_name(topic, schema, prefix, schema_type) + subject = get_subject_name(topic, schema, subject_type, naming_strategy) return subject not in subjects parsed_schema, valid_subjects = await self._query_schema_and_subjects( @@ -792,7 +800,11 @@ def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool: need_new_call=subject_not_included, ) - if self.config["name_strategy_validation"] and subject_not_included(parsed_schema, valid_subjects): + if ( + self.config["name_strategy_validation"] + and naming_strategy != NameStrategy.no_validation + and subject_not_included(parsed_schema, valid_subjects) + ): raise InvalidSchema() return schema_id @@ -833,7 +845,9 @@ async def _query_schema_id_from_cache_or_registry( ) return schema_id - async def validate_schema_info(self, data: dict, prefix: str, content_type: str, topic: str, schema_type: str): + async def validate_schema_info( + self, data: dict, subject_type: SubjectType, content_type: str, topic: str, schema_type: str + ): try: schema_type = SCHEMA_MAPPINGS[schema_type] except KeyError: @@ -848,7 +862,7 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str, # will do in place updates of id keys, since calling these twice would be expensive try: - data[f"{prefix}_schema_id"] = await self.get_schema_id(data, topic, prefix, schema_type) + data[f"{subject_type}_schema_id"] = await self.get_schema_id(data, topic, subject_type, schema_type) except InvalidPayload: log.exception("Unable to retrieve schema id") KafkaRest.r( @@ -863,16 +877,17 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str, KafkaRest.r( body={ "error_code": RESTErrorCodes.SCHEMA_RETRIEVAL_ERROR.value, - "message": f"Error when registering schema. format = {schema_type.value}, subject = {topic}-{prefix}", + "message": f"Error when registering schema." + f"format = {schema_type.value}, subject = {topic}-{subject_type}", }, content_type=content_type, status=HTTPStatus.REQUEST_TIMEOUT, ) except InvalidSchema: - if f"{prefix}_schema" in data: - err = f'schema = {data[f"{prefix}_schema"]}' + if f"{subject_type}_schema" in data: + err = f'schema = {data[f"{subject_type}_schema"]}' else: - err = f'schema_id = {data[f"{prefix}_schema_id"]}' + err = f'schema_id = {data[f"{subject_type}_schema_id"]}' KafkaRest.r( body={ "error_code": RESTErrorCodes.INVALID_DATA.value, @@ -1002,7 +1017,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte status=HTTPStatus.BAD_REQUEST, ) convert_to_int(r, "partition", content_type) - if set(r.keys()).difference(RECORD_KEYS): + if set(r.keys()).difference({subject_type.value for subject_type in SubjectType}): KafkaRest.unprocessable_entity( message="Invalid request format", content_type=content_type, @@ -1010,18 +1025,18 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte ) # disallow missing id and schema for any key/value list that has at least one populated element if formats["embedded_format"] in {"avro", "jsonschema", "protobuf"}: - for prefix, code in zip(RECORD_KEYS, RECORD_CODES): - if self.all_empty(data, prefix): + for subject_type, code in zip(SUBJECT_VALID_POSTFIX, RECORD_CODES): + if self.all_empty(data, subject_type): continue - if not self.is_valid_schema_request(data, prefix): + if not self.is_valid_schema_request(data, subject_type): KafkaRest.unprocessable_entity( - message=f"Request includes {prefix}s and uses a format that requires schemas " - f"but does not include the {prefix}_schema or {prefix}_schema_id fields", + message=f"Request includes {subject_type}s and uses a format that requires schemas " + f"but does not include the {subject_type}_schema or {subject_type.value}_schema_id fields", content_type=content_type, sub_code=code, ) try: - await self.validate_schema_info(data, prefix, content_type, topic, formats["embedded_format"]) + await self.validate_schema_info(data, subject_type, content_type, topic, formats["embedded_format"]) except InvalidMessageSchema as e: KafkaRest.unprocessable_entity( message=str(e), diff --git a/karapace/schema_reader.py b/karapace/schema_reader.py index 3dec4a887..f4c78330f 100644 --- a/karapace/schema_reader.py +++ b/karapace/schema_reader.py @@ -8,6 +8,7 @@ from avro.schema import Schema as AvroSchema from contextlib import closing, ExitStack +from enum import Enum from jsonschema.validators import Draft7Validator from kafka import KafkaConsumer, TopicPartition from kafka.admin import KafkaAdminClient, NewTopic @@ -20,7 +21,7 @@ TopicAlreadyExistsError, ) from karapace import constants -from karapace.config import Config +from karapace.config import Config, NameStrategy from karapace.dependency import Dependency from karapace.errors import InvalidReferences, InvalidSchema from karapace.in_memory_database import InMemoryDatabase @@ -31,7 +32,7 @@ from karapace.schema_models import parse_protobuf_schema_definition, SchemaType, TypedSchema, ValidatedTypedSchema from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping, Referents from karapace.statsd import StatsClient -from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject +from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, TopicName from karapace.utils import json_decode, JSONDecodeError, KarapaceKafkaClient from threading import Event, Thread from typing import Final, Mapping, Sequence @@ -58,6 +59,14 @@ METRIC_SUBJECT_DATA_SCHEMA_VERSIONS_GAUGE: Final = "karapace_schema_reader_subject_data_schema_versions" +class MessageType(Enum): + config = "CONFIG" + schema = "SCHEMA" + delete_subject = "DELETE_SUBJECT" + schema_strategy = "SCHEMA_STRATEGY" + no_operation = "NOOP" + + def _create_consumer_from_config(config: Config) -> KafkaConsumer: # Group not set on purpose, all consumers read the same data session_timeout_ms = config["session_timeout_ms"] @@ -429,6 +438,11 @@ def _handle_msg_delete_subject(self, key: dict, value: dict | None) -> None: # LOG.info("Deleting subject: %r, value: %r", subject, value) self.database.delete_subject(subject=subject, version=version) + def _handle_msg_schema_strategy(self, key: dict, value: dict | None) -> None: # pylint: disable=unused-argument + assert isinstance(value, dict) + topic, strategy = value["topic"], value["strategy"] + self.database.override_topic_strategy(topic_name=TopicName(topic), name_strategy=NameStrategy(strategy)) + def _handle_msg_schema_hard_delete(self, key: dict) -> None: subject, version = key["subject"], key["version"] @@ -522,14 +536,27 @@ def _handle_msg_schema(self, key: dict, value: dict | None) -> None: self.database.insert_referenced_by(subject=ref.subject, version=ref.version, schema_id=schema_id) def handle_msg(self, key: dict, value: dict | None) -> None: - if key["keytype"] == "CONFIG": - self._handle_msg_config(key, value) - elif key["keytype"] == "SCHEMA": - self._handle_msg_schema(key, value) - elif key["keytype"] == "DELETE_SUBJECT": - self._handle_msg_delete_subject(key, value) - elif key["keytype"] == "NOOP": # for spec completeness - pass + if "keytype" in key: + try: + message_type = MessageType(key["keytype"]) + + if message_type == MessageType.config: + self._handle_msg_config(key, value) + elif message_type == MessageType.schema: + self._handle_msg_schema(key, value) + elif message_type == MessageType.delete_subject: + self._handle_msg_delete_subject(key, value) + elif message_type == MessageType.schema_strategy: + self._handle_msg_schema_strategy(key, value) + elif message_type == MessageType.no_operation: + pass + except ValueError: + LOG.error("The message %s-%s has been discarded because the %s is not managed", key, value, key["keytype"]) + + else: + LOG.error( + "The message %s-%s has been discarded because doesn't contain the `keytype` key in the key", key, value + ) def remove_referenced_by( self, diff --git a/karapace/schema_registry.py b/karapace/schema_registry.py index 867eeb633..604892a38 100644 --- a/karapace/schema_registry.py +++ b/karapace/schema_registry.py @@ -7,7 +7,7 @@ from contextlib import AsyncExitStack, closing from karapace.compatibility import check_compatibility, CompatibilityModes from karapace.compatibility.jsonschema.checks import is_incompatible -from karapace.config import Config +from karapace.config import Config, NameStrategy from karapace.dependency import Dependency from karapace.errors import ( IncompatibleSchema, @@ -27,9 +27,9 @@ from karapace.messaging import KarapaceProducer from karapace.offset_watcher import OffsetWatcher from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema -from karapace.schema_reader import KafkaSchemaReader +from karapace.schema_reader import KafkaSchemaReader, MessageType from karapace.schema_references import LatestVersionReference, Reference -from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, Version +from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, TopicName, Version from typing import Mapping, Sequence import asyncio @@ -466,6 +466,20 @@ def send_schema_message( value = None self.producer.send_message(key=key, value=value) + def get_validation_strategy_for_topic(self, *, topic_name: TopicName) -> NameStrategy: + strategy = self.database.get_topic_strategy(topic_name=topic_name) + return strategy if strategy is not None else NameStrategy(self.config["default_name_strategy"]) + + def send_validation_strategy_for_topic( + self, + *, + topic_name: TopicName, + validation_strategy: NameStrategy, + ) -> None: + key = {"topic": topic_name, "keytype": MessageType.schema_strategy.value, "magic": 0} + value = {"strategy": validation_strategy.value, "topic": topic_name} + self.producer.send_message(key=key, value=value) + def send_config_message(self, compatibility_level: CompatibilityModes, subject: Subject | None = None) -> None: key = {"subject": subject, "magic": 0, "keytype": "CONFIG"} value = {"compatibilityLevel": compatibility_level.value} diff --git a/karapace/schema_registry_apis.py b/karapace/schema_registry_apis.py index f4d22cd78..07af2d1ea 100644 --- a/karapace/schema_registry_apis.py +++ b/karapace/schema_registry_apis.py @@ -11,7 +11,7 @@ from karapace.auth import HTTPAuthorizer, Operation, User from karapace.compatibility import check_compatibility, CompatibilityModes from karapace.compatibility.jsonschema.checks import is_incompatible -from karapace.config import Config +from karapace.config import Config, NameStrategy from karapace.errors import ( IncompatibleSchema, InvalidReferences, @@ -28,13 +28,13 @@ SubjectSoftDeletedException, VersionNotFoundException, ) -from karapace.karapace import KarapaceBase +from karapace.karapace import empty_response, KarapaceBase from karapace.protobuf.exception import ProtobufUnresolvedDependencyException from karapace.rapu import HTTPRequest, JSON_CONTENT_TYPE, SERVER_NAME from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping from karapace.schema_registry import KarapaceSchemaRegistry, validate_version -from karapace.typing import JsonData, JsonObject, ResolvedVersion, SchemaId +from karapace.typing import JsonData, JsonObject, ResolvedVersion, SchemaId, TopicName from karapace.utils import JSONDecodeError from typing import Any @@ -301,6 +301,23 @@ def _add_schema_registry_routes(self) -> None: json_body=False, auth=self._auth, ) + self.route( + "/topic//name_strategy", + callback=self.subject_validation_strategy_get, + method="GET", + schema_request=True, + json_body=False, + auth=None, + ) + self.route( + "/topic//name_strategy/", + callback=self.subject_validation_strategy_set, + method="POST", + schema_request=True, + with_request=True, + json_body=False, + auth=None, + ) async def close(self) -> None: async with AsyncExitStack() as stack: @@ -985,6 +1002,38 @@ def _validate_schema_type(self, content_type: str, data: JsonData) -> SchemaType ) return schema_type + def _validate_topic_name(self, topic: str) -> TopicName: + valid_topic_names = self.schema_registry.schema_reader.admin_client.list_topics() + + if topic in valid_topic_names: + return TopicName(topic) + + self.r( + body={ + "error_code": SchemaErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value, + "message": f"The topic {topic} isn't existing, proceed with creating it first", + }, + content_type=JSON_CONTENT_TYPE, + status=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + + def _validate_name_strategy(self, name_strategy: str) -> NameStrategy: + try: + strategy = NameStrategy(name_strategy) + return strategy + except ValueError: + valid_strategies = list(NameStrategy) + error_message = f"Invalid name strategy: {name_strategy}, valid values are {valid_strategies}" + + self.r( + body={ + "error_code": SchemaErrorCodes.HTTP_UNPROCESSABLE_ENTITY.value, + "message": error_message, + }, + content_type=JSON_CONTENT_TYPE, + status=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + def _validate_schema_key(self, content_type: str, body: dict) -> None: if "schema" not in body: self.r( @@ -1238,6 +1287,44 @@ async def subject_post( url = f"{master_url}/subjects/{subject}/versions" await self._forward_request_remote(request=request, body=body, url=url, content_type=content_type, method="POST") + async def subject_validation_strategy_get(self, content_type: str, *, topic: str) -> None: + strategy_name = self.schema_registry.get_validation_strategy_for_topic(topic_name=TopicName(topic)).value + reply = {"strategy": strategy_name} + self.r(reply, content_type) + + async def subject_validation_strategy_set( + self, + content_type: str, + request: HTTPRequest, + *, + topic: str, + strategy: str, + ) -> None: + # proceeding with the strategy first since it's cheaper + strategy_name = self._validate_name_strategy(strategy) + # real validation of the topic name commented, do we need to do that? does it make sense? + topic_name = TopicName(topic) # self._validate_topic_name(topic) + + are_we_master, master_url = await self.schema_registry.get_master() + if are_we_master: + self.schema_registry.send_validation_strategy_for_topic( + topic_name=topic_name, + validation_strategy=strategy_name, + ) + empty_response() + else: + # I don't really like it, in theory we should parse the URL and change only the host portion while + # keeping the rest the same + url = f"{master_url}/topic/{topic}/name_strategy" + + await self._forward_request_remote( + request=request, + body=None, + url=url, + content_type=content_type, + method="POST", + ) + def get_schema_id_if_exists(self, *, subject: str, schema: TypedSchema, include_deleted: bool) -> SchemaId | None: schema_id = self.schema_registry.database.get_schema_id_if_exists( subject=subject, schema=schema, include_deleted=include_deleted diff --git a/karapace/serialization.py b/karapace/serialization.py index 29dc51a6c..8e4346856 100644 --- a/karapace/serialization.py +++ b/karapace/serialization.py @@ -9,13 +9,15 @@ from google.protobuf.message import DecodeError from jsonschema import ValidationError from karapace.client import Client +from karapace.config import NameStrategy from karapace.dependency import Dependency from karapace.errors import InvalidReferences +from karapace.kafka_rest_apis import SubjectType from karapace.protobuf.exception import ProtobufTypeException from karapace.protobuf.io import ProtobufDatumReader, ProtobufDatumWriter from karapace.schema_models import InvalidSchema, ParsedTypedSchema, SchemaType, TypedSchema, ValidatedTypedSchema from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping -from karapace.typing import ResolvedVersion, SchemaId, Subject +from karapace.typing import ResolvedVersion, SchemaId, Subject, TopicName from karapace.utils import json_decode, json_encode from typing import Any, Callable, Dict, List, MutableMapping, Optional, Set, Tuple from urllib.parse import quote @@ -71,10 +73,15 @@ def topic_record_name_strategy(topic_name: str, record_name: str) -> str: return topic_name + "-" + record_name +def no_validation_strategy(topic_name: str, record_name: str) -> str: + return f"__auto_registration_anonymous_{topic_record_name_strategy(topic_name, record_name)}" + + NAME_STRATEGIES = { - "topic_name": topic_name_strategy, - "record_name": record_name_strategy, - "topic_record_name": topic_record_name_strategy, + NameStrategy.topic_name: topic_name_strategy, + NameStrategy.record_name: record_name_strategy, + NameStrategy.topic_record_name: topic_record_name_strategy, + NameStrategy.no_validation: no_validation_strategy, } @@ -103,7 +110,7 @@ async def post_new_schema( raise SchemaRetrievalError(result.json()) return SchemaId(result.json()["id"]) - async def _get_schema_r( + async def _get_schema_recursive( self, subject: Subject, explored_schemas: Set[Tuple[Subject, Optional[ResolvedVersion]]], @@ -131,7 +138,7 @@ async def _get_schema_r( references = [Reference.from_dict(data) for data in json_result["references"]] dependencies = {} for reference in references: - _, schema, version = await self._get_schema_r(reference.subject, explored_schemas, reference.version) + _, schema, version = await self._get_schema_recursive(reference.subject, explored_schemas, reference.version) dependencies[reference.name] = Dependency( name=reference.name, subject=reference.subject, version=version, target_schema=schema ) @@ -174,7 +181,7 @@ async def get_schema( - ValidatedTypedSchema: The retrieved schema, validated and typed. - ResolvedVersion: The version of the schema that was retrieved. """ - return await self._get_schema_r(subject, set(), version) + return await self._get_schema_recursive(subject, set(), version) async def get_schema_for_id(self, schema_id: SchemaId) -> Tuple[TypedSchema, List[Subject]]: result = await self.client.get(f"schemas/ids/{schema_id}", params={"includeSubjects": "True"}) @@ -225,6 +232,25 @@ async def close(self): await self.client.close() +def get_subject_name( + topic_name: str, + schema: TypedSchema, + subject_type: SubjectType, + naming_strategy: NameStrategy, +) -> Subject: + namespace = "dummy" + if schema.schema_type is SchemaType.AVRO: + if isinstance(schema.schema, avro.schema.NamedSchema): + namespace = schema.schema.namespace or "" + if schema.schema_type is SchemaType.JSONSCHEMA: + namespace = schema.to_dict().get("namespace", "dummy") + # Protobuf does not use namespaces in terms of AVRO + if schema.schema_type is SchemaType.PROTOBUF: + namespace = "" + naming_strategy = NAME_STRATEGIES[naming_strategy] + return Subject(f"{naming_strategy(topic_name, namespace)}-{subject_type}") + + class SchemaRegistrySerializer: def __init__( self, @@ -243,36 +269,26 @@ def __init__( else: registry_url = f"http://{self.config['registry_host']}:{self.config['registry_port']}" registry_client = SchemaRegistryClient(registry_url, session_auth=session_auth) - name_strategy = config.get("name_strategy", "topic_name") - self.subject_name_strategy = NAME_STRATEGIES.get(name_strategy, topic_name_strategy) self.registry_client: Optional[SchemaRegistryClient] = registry_client self.ids_to_schemas: Dict[int, TypedSchema] = {} self.ids_to_subjects: MutableMapping[int, List[Subject]] = TTLCache(maxsize=10000, ttl=600) self.schemas_to_ids: Dict[str, SchemaId] = {} + self._topic_strategy_cache: MutableMapping[TopicName, NameStrategy] = TTLCache(maxsize=10000, ttl=600) async def close(self) -> None: if self.registry_client: await self.registry_client.close() self.registry_client = None - def get_subject_name( - self, - topic_name: str, - schema: TypedSchema, - subject_type: str, - schema_type: SchemaType, - ) -> Subject: - namespace = "dummy" - if schema_type is SchemaType.AVRO: - if isinstance(schema.schema, avro.schema.NamedSchema): - namespace = schema.schema.namespace - if schema_type is SchemaType.JSONSCHEMA: - namespace = schema.to_dict().get("namespace", "dummy") - # Protobuf does not use namespaces in terms of AVRO - if schema_type is SchemaType.PROTOBUF: - namespace = "" - - return Subject(f"{self.subject_name_strategy(topic_name, namespace)}-{subject_type}") + async def get_topic_strategy_name(self, topic_name: TopicName) -> NameStrategy: + assert self.registry_client, "must not call this method after the object is closed." + if topic_name in self._topic_strategy_cache: + return self._topic_strategy_cache[topic_name] + result = await self.registry_client.client.get(f"topic/{topic_name}/name_strategy") + + strategy = NameStrategy(result.json()["strategy"]) + self._topic_strategy_cache[topic_name] = strategy + return strategy async def get_schema_for_subject(self, subject: Subject) -> TypedSchema: assert self.registry_client, "must not call this method after the object is closed." diff --git a/tests/conftest.py b/tests/conftest.py index 3b903c699..99ba55809 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -160,7 +160,7 @@ def fixture_session_logdir(request, tmp_path_factory, worker_id) -> Path: @pytest.fixture(scope="session", name="default_config_path") -def fixture_default_config(session_logdir: Path) -> str: +def fixture_default_config(session_logdir: Path) -> Path: path = session_logdir / "karapace_config.json" content = json.dumps({"registry_host": "localhost", "registry_port": 8081}).encode() content_len = len(content) @@ -170,7 +170,7 @@ def fixture_default_config(session_logdir: Path) -> str: raise OSError(f"Writing config failed, tried to write {content_len} bytes, but only {written} were written") fp.flush() os.fsync(fp) - return str(path) + return path @pytest.fixture(name="tmp_file", scope="function") diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 1c16b3c2b..b589afd45 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -7,12 +7,12 @@ from _pytest.fixtures import SubRequest from aiohttp.pytest_plugin import AiohttpClient from aiohttp.test_utils import TestClient -from contextlib import closing, ExitStack +from contextlib import asynccontextmanager, closing, ExitStack from dataclasses import asdict from filelock import FileLock from kafka import KafkaProducer from karapace.client import Client -from karapace.config import Config, set_config_defaults, write_config +from karapace.config import Config, ConfigDefaults, set_config_defaults, write_config from karapace.kafka_rest_apis import KafkaRest, KafkaRestAdminClient from pathlib import Path from tests.conftest import KAFKA_VERSION @@ -29,7 +29,7 @@ from tests.integration.utils.synchronization import lock_path_for from tests.integration.utils.zookeeper import configure_and_start_zk from tests.utils import repeat_until_successful_request -from typing import AsyncIterator, Iterator, List, Optional +from typing import AsyncContextManager, AsyncIterator, Callable, Iterator, List, Optional from urllib.parse import urlparse import asyncio @@ -452,13 +452,14 @@ async def fixture_registry_async_pair( yield [server.endpoint.to_url() for server in endpoints] -@pytest.fixture(scope="function", name="registry_cluster") -async def fixture_registry_cluster( +@asynccontextmanager +async def _registry_cluster( request: SubRequest, loop: asyncio.AbstractEventLoop, # pylint: disable=unused-argument session_logdir: Path, kafka_servers: KafkaServers, port_range: PortRangeInclusive, + custom_values: Optional[ConfigDefaults] = None, ) -> AsyncIterator[RegistryDescription]: # Do not start a registry when the user provided an external service. Doing # so would cause this node to join the existing group and participate in @@ -476,12 +477,56 @@ async def fixture_registry_cluster( config_templates=[config], data_dir=session_logdir / _clear_test_name(request.node.name), port_range=port_range, + custom_values=custom_values, ) as servers: yield servers[0] -@pytest.fixture(scope="function", name="registry_async_client") -async def fixture_registry_async_client( +@pytest.fixture(scope="function", name="registry_cluster") +async def fixture_registry_cluster( + request: SubRequest, + loop: asyncio.AbstractEventLoop, + session_logdir: Path, + kafka_servers: KafkaServers, + port_range: PortRangeInclusive, + custom_values: Optional[ConfigDefaults] = None, +) -> AsyncIterator[RegistryDescription]: + async with _registry_cluster( + request, + loop, + session_logdir, + kafka_servers, + port_range, + custom_values, + ) as registry_description: + yield registry_description + + +@pytest.fixture(scope="function", name="registry_cluster_from_custom_config") +def fixture_registry_cluster_with_custom_config( + request: SubRequest, + loop: asyncio.AbstractEventLoop, + session_logdir: Path, + kafka_servers: KafkaServers, + port_range: PortRangeInclusive, +) -> Callable[[ConfigDefaults], AsyncContextManager[RegistryDescription]]: + @asynccontextmanager + async def registry_from_custom_config(config: ConfigDefaults) -> RegistryDescription: + async with _registry_cluster( + request, + loop, + session_logdir, + kafka_servers, + port_range, + config, + ) as registry_description: + yield registry_description + + return registry_from_custom_config + + +@asynccontextmanager +async def _registry_async_client( request: SubRequest, registry_cluster: RegistryDescription, loop: asyncio.AbstractEventLoop, # pylint: disable=unused-argument @@ -507,6 +552,35 @@ async def fixture_registry_async_client( await client.close() +@pytest.fixture(scope="function", name="registry_async_client") +async def fixture_registry_async_client( + request: SubRequest, + registry_cluster: RegistryDescription, + loop: asyncio.AbstractEventLoop, +) -> Client: + async with _registry_async_client( + request, + registry_cluster, + loop, + ) as client: + yield client + + +@pytest.fixture(scope="function", name="registry_async_client_from_custom_config") +def fixture_registry_async_client_custom_config( + request: SubRequest, + registry_cluster_from_custom_config: Callable[[ConfigDefaults], AsyncIterator[RegistryDescription]], + loop: asyncio.AbstractEventLoop, +) -> Callable[[ConfigDefaults], AsyncContextManager[Client]]: + @asynccontextmanager + async def client_from_custom_config(config: ConfigDefaults) -> Client: + async with registry_cluster_from_custom_config(config) as registry_description: + async with _registry_async_client(request, registry_description, loop) as client: + yield client + + return client_from_custom_config + + @pytest.fixture(scope="function", name="credentials_folder") def fixture_credentials_folder() -> str: integration_test_folder = os.path.dirname(__file__) diff --git a/tests/integration/test_rest.py b/tests/integration/test_rest.py index 9fec19285..8c8413e8a 100644 --- a/tests/integration/test_rest.py +++ b/tests/integration/test_rest.py @@ -7,7 +7,11 @@ from kafka import KafkaProducer from kafka.errors import UnknownTopicOrPartitionError from karapace.client import Client -from karapace.kafka_rest_apis import KafkaRest, KafkaRestAdminClient +from karapace.config import NameStrategy, SubjectType +from karapace.kafka_rest_apis import KafkaRest, KafkaRestAdminClient, SUBJECT_VALID_POSTFIX +from karapace.schema_models import ValidatedTypedSchema +from karapace.schema_type import SchemaType +from karapace.serialization import get_subject_name from karapace.version import __version__ from pytest import raises from tests.integration.conftest import REST_PRODUCER_MAX_REQUEST_BYTES @@ -26,6 +30,7 @@ import asyncio import base64 import json +import pytest import time NEW_TOPIC_TIMEOUT = 10 @@ -172,9 +177,9 @@ async def test_avro_publish( new_schema_id = res.json()["id"] # test checks schema id use for key and value, register schema for both with topic naming strategy - for pl_type in ["key", "value"]: + for pl_type in SUBJECT_VALID_POSTFIX: res = await registry_async_client.post( - f"subjects/{tn}-{pl_type}/versions", json={"schema": schema_avro_json_evolution} + f"subjects/{tn}-{pl_type.value}/versions", json={"schema": schema_avro_json_evolution} ) assert res.ok assert res.json()["id"] == new_schema_id @@ -651,6 +656,124 @@ async def test_publish_with_schema_id_of_another_subject_novalidation( assert res.status_code == 200 +@pytest.mark.parametrize( + "strategy", + ( + NameStrategy.topic_name, + NameStrategy.record_name, + NameStrategy.topic_record_name, + ), +) +async def test_produce_subjects_with_different_name_strategies( + rest_async_client: Client, + registry_async_client: Client, + admin_client: KafkaRestAdminClient, + strategy: NameStrategy, +) -> None: + topic_name = new_topic(admin_client) + + await wait_for_topics(rest_async_client, topic_names=[topic_name], timeout=NEW_TOPIC_TIMEOUT, sleep=1) + create_messages_url = f"/topics/{topic_name}" + + typed_schema = ValidatedTypedSchema.parse( + SchemaType.AVRO, + json.dumps( + { + "type": "record", + "name": "Schema1", + "fields": [ + { + "name": "name", + "type": "string", + }, + ], + } + ), + ) + + res = await registry_async_client.post(f"/topic/{topic_name}/name_strategy/{strategy}", json={}) + assert res.ok + + # without the right subject it should fail even if the schema it's correct + res = await registry_async_client.post( + "subjects/random_subject_name/versions", + json={"schema": str(typed_schema)}, + ) + assert res.status_code == 200 + random_subject_name_id = res.json()["id"] + + res = await rest_async_client.post( + create_messages_url, + json={"value_schema_id": random_subject_name_id, "records": [{"value": {"name": "Mr. Mustache"}}]}, + headers=REST_HEADERS["avro"], + ) + assert res.status_code == 422 + + # registering the required subject + subject_to_create = get_subject_name(topic_name, typed_schema, SubjectType.value, strategy) + + res = await registry_async_client.post( + f"subjects/{subject_to_create}/versions", + json={"schema": str(typed_schema)}, + ) + assert res.status_code == 200 + schema_id = res.json()["id"] + + # trying to produce with the subject correctly registered + res = await rest_async_client.post( + create_messages_url, + json={"value_schema_id": schema_id, "records": [{"value": {"name": "Mr. Mustache"}}]}, + headers=REST_HEADERS["avro"], + ) + assert res.status_code == 200 + + +async def test_can_produce_anything_with_no_validation_policy( + rest_async_client: Client, + registry_async_client: Client, + admin_client: KafkaRestAdminClient, +) -> None: + topic_name = new_topic(admin_client) + + await wait_for_topics(rest_async_client, topic_names=[topic_name], timeout=NEW_TOPIC_TIMEOUT, sleep=1) + + typed_schema = ValidatedTypedSchema.parse( + SchemaType.AVRO, + json.dumps( + { + "type": "record", + "name": "Schema1", + "fields": [ + { + "name": "name", + "type": "string", + }, + ], + } + ), + ) + + res = await registry_async_client.post(f"/topic/{topic_name}/name_strategy/{NameStrategy.no_validation}", json={}) + assert res.ok + + # with the no_validation strategy we can produce even if we use a totally random subject name + create_messages_url = f"/topics/{topic_name}" + + res = await registry_async_client.post( + "subjects/random_subject_name/versions", + json={"schema": str(typed_schema)}, + ) + assert res.status_code == 200 + random_subject_name_id = res.json()["id"] + + res = await rest_async_client.post( + create_messages_url, + json={"value_schema_id": random_subject_name_id, "records": [{"value": {"name": "Mr. Mustache"}}]}, + headers=REST_HEADERS["avro"], + ) + assert res.status_code == 200 + + async def test_brokers(rest_async_client: Client) -> None: res = await rest_async_client.get("/brokers") assert res.ok diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 4e325a2a0..6c08d3e8b 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -7,6 +7,7 @@ from http import HTTPStatus from kafka import KafkaProducer from karapace.client import Client +from karapace.config import ConfigDefaults, NameStrategy from karapace.rapu import is_success from karapace.schema_registry_apis import SchemaErrorMessages from karapace.utils import json_encode @@ -18,7 +19,7 @@ create_subject_name_factory, repeat_until_successful_request, ) -from typing import List, Tuple +from typing import AsyncIterator, Callable, List, Tuple import asyncio import json @@ -1079,6 +1080,38 @@ async def assert_schema_versions_failed(client: Client, trail: str, schema_id: i assert res.status_code == response_code +@pytest.mark.parametrize( + "strategy", + ( + NameStrategy.topic_name, + NameStrategy.record_name, + NameStrategy.topic_record_name, + NameStrategy.no_validation, + ), +) +async def test_default_name_strategy_no_validation( + registry_async_client_from_custom_config: Callable[[ConfigDefaults], AsyncIterator[RegistryDescription]], + strategy: NameStrategy, +) -> None: + async with registry_async_client_from_custom_config({"default_name_strategy": strategy}) as registry_client: + res = await registry_client.get("/topic/foo/name_strategy") + assert res.ok + assert res.json() == {"strategy": strategy.value} + + +async def test_set_name_strategy(registry_async_client: Client) -> None: + res = await registry_async_client.get("/topic/foo/name_strategy") + assert res.ok + assert res.json() == {"strategy": NameStrategy.topic_name} + + res = await registry_async_client.post(f"/topic/foo/name_strategy/{NameStrategy.record_name}", json={}) + assert res.ok + + res = await registry_async_client.get("/topic/foo/name_strategy") + assert res.ok + assert res.json() == {"strategy": NameStrategy.record_name} + + async def register_schema(registry_async_client: Client, trail, subject: str, schema_str: str) -> Tuple[int, int]: # Register to get the id res = await registry_async_client.post( diff --git a/tests/integration/utils/cluster.py b/tests/integration/utils/cluster.py index 31c06e4bd..b82d869eb 100644 --- a/tests/integration/utils/cluster.py +++ b/tests/integration/utils/cluster.py @@ -4,12 +4,12 @@ """ from contextlib import asynccontextmanager, ExitStack from dataclasses import dataclass -from karapace.config import Config, set_config_defaults, write_config +from karapace.config import Config, ConfigDefaults, set_config_defaults, write_config from pathlib import Path from tests.integration.utils.network import PortRangeInclusive from tests.integration.utils.process import stop_process, wait_for_port_subprocess from tests.utils import new_random_name, popen_karapace_all -from typing import AsyncIterator, List +from typing import AsyncIterator, List, Optional @dataclass(frozen=True) @@ -33,6 +33,7 @@ async def start_schema_registry_cluster( config_templates: List[Config], data_dir: Path, port_range: PortRangeInclusive, + custom_values: Optional[ConfigDefaults] = None, ) -> AsyncIterator[List[RegistryDescription]]: """Start a cluster of schema registries, one process per `config_templates`.""" for template in config_templates: @@ -76,7 +77,14 @@ async def start_schema_registry_cluster( log_path = group_dir / f"{pos}.log" error_path = group_dir / f"{pos}.error" - config = set_config_defaults(config) + config = ( + set_config_defaults(config) + if custom_values is None + else { + **dict(item for item in set_config_defaults(config).items() if item[0] not in custom_values), + **custom_values, + } + ) write_config(config_path, config) logfile = stack.enter_context(open(log_path, "w")) diff --git a/tests/unit/test_protobuf_serialization.py b/tests/unit/test_protobuf_serialization.py index 3acd344b8..db039c64f 100644 --- a/tests/unit/test_protobuf_serialization.py +++ b/tests/unit/test_protobuf_serialization.py @@ -15,6 +15,7 @@ START_BYTE, ) from karapace.typing import ResolvedVersion, Subject +from pathlib import Path from tests.utils import schema_protobuf, test_fail_objects_protobuf, test_objects_protobuf from unittest.mock import call, Mock @@ -35,7 +36,7 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali return serializer -async def test_happy_flow(default_config_path): +async def test_happy_flow(default_config_path: Path): mock_protobuf_registry_client = Mock() schema_for_id_one_future = asyncio.Future() schema_for_id_one_future.set_result( @@ -61,7 +62,7 @@ async def test_happy_flow(default_config_path): assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] -async def test_happy_flow_references(default_config_path): +async def test_happy_flow_references(default_config_path: Path): no_ref_schema_str = """ |syntax = "proto3"; | @@ -129,7 +130,7 @@ async def test_happy_flow_references(default_config_path): assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] -async def test_happy_flow_references_two(default_config_path): +async def test_happy_flow_references_two(default_config_path: Path): no_ref_schema_str = """ |syntax = "proto3"; | @@ -216,7 +217,7 @@ async def test_happy_flow_references_two(default_config_path): assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] -async def test_serialization_fails(default_config_path): +async def test_serialization_fails(default_config_path: Path): mock_protobuf_registry_client = Mock() get_latest_schema_future = asyncio.Future() get_latest_schema_future.set_result( @@ -239,7 +240,7 @@ async def test_serialization_fails(default_config_path): assert mock_protobuf_registry_client.method_calls == [call.get_schema("top")] -async def test_deserialization_fails(default_config_path): +async def test_deserialization_fails(default_config_path: Path): mock_protobuf_registry_client = Mock() deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) @@ -258,7 +259,7 @@ async def test_deserialization_fails(default_config_path): assert mock_protobuf_registry_client.method_calls == [call.get_schema_for_id(500)] -async def test_deserialization_fails2(default_config_path): +async def test_deserialization_fails2(default_config_path: Path): mock_protobuf_registry_client = Mock() deserializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index 029cae393..071f53f14 100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -2,10 +2,13 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ -from karapace.config import DEFAULTS, read_config +from karapace.client import Path +from karapace.config import DEFAULTS, NameStrategy, read_config, SubjectType +from karapace.kafka_rest_apis import SUBJECT_VALID_POSTFIX from karapace.schema_models import SchemaType, ValidatedTypedSchema from karapace.serialization import ( flatten_unions, + get_subject_name, HEADER_FORMAT, InvalidMessageHeader, InvalidMessageSchema, @@ -29,6 +32,27 @@ log = logging.getLogger(__name__) +TYPED_SCHEMA = ValidatedTypedSchema.parse( + SchemaType.AVRO, + json.dumps( + { + "namespace": "io.aiven.data", + "name": "Test", + "type": "record", + "fields": [ + { + "name": "attr1", + "type": ["null", "string"], + }, + { + "name": "attr2", + "type": ["null", "string"], + }, + ], + } + ), +) + async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySerializer: with open(config_path, encoding="utf8") as handler: @@ -39,7 +63,7 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali return serializer -async def test_happy_flow(default_config_path): +async def test_happy_flow(default_config_path: Path): mock_registry_client = Mock() get_latest_schema_future = asyncio.Future() get_latest_schema_future.set_result( @@ -62,32 +86,12 @@ async def test_happy_flow(default_config_path): def test_flatten_unions_record() -> None: - typed_schema = ValidatedTypedSchema.parse( - SchemaType.AVRO, - json.dumps( - { - "namespace": "io.aiven.data", - "name": "Test", - "type": "record", - "fields": [ - { - "name": "attr1", - "type": ["null", "string"], - }, - { - "name": "attr2", - "type": ["null", "string"], - }, - ], - } - ), - ) record = {"attr1": {"string": "sample data"}, "attr2": None} flatten_record = {"attr1": "sample data", "attr2": None} - assert flatten_unions(typed_schema.schema, record) == flatten_record + assert flatten_unions(TYPED_SCHEMA.schema, record) == flatten_record record = {"attr1": None, "attr2": None} - assert flatten_unions(typed_schema.schema, record) == record + assert flatten_unions(TYPED_SCHEMA.schema, record) == record def test_flatten_unions_array() -> None: @@ -248,7 +252,7 @@ def test_avro_json_write_accepts_json_encoded_data_without_tagged_unions() -> No assert buffer_a.getbuffer() == buffer_b.getbuffer() -async def test_serialization_fails(default_config_path): +async def test_serialization_fails(default_config_path: Path): mock_registry_client = Mock() get_latest_schema_future = asyncio.Future() get_latest_schema_future.set_result( @@ -264,7 +268,7 @@ async def test_serialization_fails(default_config_path): assert mock_registry_client.method_calls == [call.get_schema("topic")] -async def test_deserialization_fails(default_config_path): +async def test_deserialization_fails(default_config_path: Path): mock_registry_client = Mock() schema_for_id_one_future = asyncio.Future() schema_for_id_one_future.set_result((ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")])) @@ -310,3 +314,31 @@ async def test_deserialization_fails(default_config_path): await deserializer.deserialize(enc_bytes) assert mock_registry_client.method_calls == [call.get_schema_for_id(1)] + + +@pytest.mark.parametrize( + "expected_subject,strategy,subject_type", + ( + (Subject("foo-key"), NameStrategy.topic_name, SUBJECT_VALID_POSTFIX[0]), + (Subject("io.aiven.data-key"), NameStrategy.record_name, SUBJECT_VALID_POSTFIX[0]), + (Subject("foo-io.aiven.data-key"), NameStrategy.topic_record_name, SUBJECT_VALID_POSTFIX[0]), + ( + Subject("__auto_registration_anonymous_foo-io.aiven.data-key"), + NameStrategy.no_validation, + SUBJECT_VALID_POSTFIX[0], + ), + (Subject("foo-value"), NameStrategy.topic_name, SUBJECT_VALID_POSTFIX[1]), + (Subject("io.aiven.data-value"), NameStrategy.record_name, SUBJECT_VALID_POSTFIX[1]), + (Subject("foo-io.aiven.data-value"), NameStrategy.topic_record_name, SUBJECT_VALID_POSTFIX[1]), + ( + Subject("__auto_registration_anonymous_foo-io.aiven.data-value"), + NameStrategy.no_validation, + SUBJECT_VALID_POSTFIX[1], + ), + ), +) +def test_name_strategy(expected_subject: Subject, strategy: NameStrategy, subject_type: SubjectType): + assert ( + get_subject_name(topic_name="foo", schema=TYPED_SCHEMA, subject_type=subject_type, naming_strategy=strategy) + == expected_subject + )