From 1f40aa09bdd3ab465ab50aa7b168ebedb9ecf415 Mon Sep 17 00:00:00 2001 From: generalpy Date: Sun, 15 Oct 2023 13:56:02 +0530 Subject: [PATCH] formatted code --- redis_clone/redis_parser.py | 130 +++++++++++++++++--------------- redis_clone/response_builder.py | 67 ++++++++-------- redis_clone/server.py | 100 +++++++++++++++--------- setup.py | 17 ++--- tests/test_client_parser.py | 42 +++++------ tests/test_server.py | 5 ++ 6 files changed, 198 insertions(+), 163 deletions(-) diff --git a/redis_clone/redis_parser.py b/redis_clone/redis_parser.py index b7f7300..f2db47b 100644 --- a/redis_clone/redis_parser.py +++ b/redis_clone/redis_parser.py @@ -1,65 +1,70 @@ from enum import Enum -PROTOCOL_SEPARATOR = '\r\n' +PROTOCOL_SEPARATOR = "\r\n" + class Protocol_2_Data_Types(Enum): - ''' + """ For Simple Strings, the first byte of the reply is "+" For Errors, the first byte of the reply is "-" For Integers, the first byte of the reply is ":" For Bulk Strings, the first byte of the reply is "$" For Arrays, the first byte of the reply is "*" - ''' - SIMPLE_STRING = '+' - ERROR = '-' - INTEGER = ':' - BULK_STRING = '$' - ARRAY = '*' - + """ + + SIMPLE_STRING = "+" + ERROR = "-" + INTEGER = ":" + BULK_STRING = "$" + ARRAY = "*" + + class Parser: def __init__(self, protocol_version) -> None: self.protocol_version = protocol_version - + def parse_client_request(self, data): - ''' + """ This function parses the client request and returns the command name and arguments - ''' + """ if self.protocol_version == 2: return self._parse_v2_client_request(data) else: - raise Exception('Protocol version not supported') - + raise Exception("Protocol version not supported") + def _parse_v2_client_request(self, data): - ''' + """ Implementing the RESP2 protocol ref: https://redis.io/docs/reference/protocol-spec/#resp-versions Commands are Array of Bulk Strings Syntax for arrays is: *\r\n... Where each element has its own type specifier Syntax for Bulk Strings is: $\r\n\r\n Where length is the number of bytes in data - ''' + """ if not data: return None # Check if first byte is an array specifier else raise exception if data[0] != Protocol_2_Data_Types.ARRAY.value: - raise Exception('Invalid protocol data') - + raise Exception("Invalid protocol data") + # Split data according to separator of protocol # We'll split only once because we need to get number of elements in array command_items = data.split(PROTOCOL_SEPARATOR, 1) - + # Get number of elements in array # First item will be * rest should be number of elements num_elements = int(command_items[0][1:]) - + # Get command name # Syntax of command is ... # So command name is first element after array specifier # But we have both command name and arguments in the same array # But we also know that command will be like $\r\n\r\n... # We need first 2 elements after array specifier as full string for parsing command name as we'll use data parser - command_name = self.parse_data('\r\n'.join(command_items[1].split(PROTOCOL_SEPARATOR)[:2])) - + command_name = self.parse_data( + "\r\n".join(command_items[1].split(PROTOCOL_SEPARATOR)[:2]) + ) + # Get command arguments # Syntax of command is ... # So command arguments are elements after command name @@ -69,27 +74,28 @@ def _parse_v2_client_request(self, data): # For data parser, we need to 2 items each, length and data command_args = [] unparsed_args = [ - '\r\n'.join(command_items[1].split(PROTOCOL_SEPARATOR)[i:i+2]) for i in range(2, (num_elements * 2)-1, 2) + "\r\n".join(command_items[1].split(PROTOCOL_SEPARATOR)[i : i + 2]) + for i in range(2, (num_elements * 2) - 1, 2) ] - + for arg in unparsed_args: command_args.append(self.parse_data(arg)) - + return command_name, command_args - + def parse_data(self, data): - ''' + """ Parses normal redis data and returns the parsed to python data type - + Data format differs based on the type of data but general syntax is [data-specific-fields\r\n]\r\n - ''' + """ if self.protocol_version != 2: - raise Exception('Protocol version not supported') - + raise Exception("Protocol version not supported") + # Get first byte of data to determine type data_type = data[0] - + # Using simple if else ladder because data types are mutually exclusive if data_type == Protocol_2_Data_Types.SIMPLE_STRING.value: return self._parse_simple_string(data) @@ -102,98 +108,98 @@ def parse_data(self, data): elif data_type == Protocol_2_Data_Types.ARRAY.value: return self._parse_array(data) else: - raise Exception('Invalid protocol data') - + raise Exception("Invalid protocol data") + def _parse_simple_string(self, data): - ''' + """ Simple Strings are used to transmit non binary safe strings with minimal overhead. They are encoded in the following way: +\r\n - ''' + """ # Split data according to separator of protocol data_items = data.split(PROTOCOL_SEPARATOR) - + # Get data # Syntax of simple string is + # So data is second element after simple string specifier data = data_items[0][1:] - + return data - + def _parse_error(self, data): - ''' + """ Errors are used in order to signal client errors. They are encoded in the following way: -\r\n - ''' + """ # Split data according to separator of protocol data_items = data.split(PROTOCOL_SEPARATOR) - + # Get data # Syntax of error is - # So data is second element after error specifier data = data_items[0][1:] - + return data def _parse_integer(self, data): - ''' + """ Integers are used in order to transmit integers from the Redis server to the client. They are encoded in the following way: :[<+|->]\r\n An optional plus (+) or minus (-) as the sign. - ''' + """ # Split data according to separator of protocol data_items = data.split(PROTOCOL_SEPARATOR) - + # Get data # Syntax of integer is :[<+|->] # So data is second element after integer specifier data = data_items[0][1:] - + return int(data) - + def _parse_bulk_string(self, data): - ''' + """ Bulk Strings are used in order to represent a single binary safe string up to 512 MB in length. They are encoded in the following way: $\r\n\r\n Where length is the number of bytes in data - ''' - + """ + # Split data according to separator of protocol data_items = data.split(PROTOCOL_SEPARATOR) - + # Get length # Syntax of bulk string is $ # So length is second element after bulk string specifier length = int(data_items[0][1:]) - + # Get data # Syntax of bulk string is $\r\n\r\n # So data is third element after bulk string specifier data = data_items[1] - + # Check if length of data is same as length specified if len(data) != length: - raise Exception('Invalid protocol data') - + raise Exception("Invalid protocol data") + return data - + def _parse_array(self, data): - ''' + """ Arrays are used in order to represent a list of other RESP data types. They are encoded in the following way: *\r\n... Where each element has its own type specifier - ''' + """ # Split data according to separator of protocol data_items = data.split(PROTOCOL_SEPARATOR) - + # Get number of elements in array # First item will be * rest should be number of elements num_elements = int(data_items[0][1:]) - + # Get elements # Syntax of array is *\r\n... # So elements are from second element after array specifier to end @@ -201,5 +207,5 @@ def _parse_array(self, data): elements = [] for element in data_items[1:]: elements.append(self.parse_data(element)) - + return elements diff --git a/redis_clone/response_builder.py b/redis_clone/response_builder.py index 05b6c95..149ca9a 100644 --- a/redis_clone/response_builder.py +++ b/redis_clone/response_builder.py @@ -2,32 +2,33 @@ class ResponseBuilder: - ''' + """ Builds the response that will be sent to the client Data is encoded according to the Redis protocol and types are converted to bytes - ''' + """ + def __init__(self, protocol_version=2) -> None: self.protocol_version = protocol_version - + def respond_with_ok(self): - ''' + """ Respond with ok - ''' - return self.build_response('OK', Protocol_2_Data_Types.SIMPLE_STRING) - + """ + return self.build_response("OK", Protocol_2_Data_Types.SIMPLE_STRING) + def build_response(self, type, data=None): - ''' + """ Build response according to protocol version - ''' + """ if self.protocol_version == 2: return self._build_protocol_2_response(type, data) else: - raise Exception('Protocol version not supported') - + raise Exception("Protocol version not supported") + def _build_protocol_2_response(self, type, data): - ''' + """ Build response according to protocol version 2 - ''' + """ if type == Protocol_2_Data_Types.ERROR: return self._build_protocol_2_error(data) elif type == Protocol_2_Data_Types.SIMPLE_STRING: @@ -35,47 +36,47 @@ def _build_protocol_2_response(self, type, data): elif type == Protocol_2_Data_Types.BULK_STRING: return self._build_protocol_2_bulk_string(data) else: - raise Exception('Invalid protocol data type') - + raise Exception("Invalid protocol data type") + def _build_protocol_2_error(self, data): - ''' + """ Errors are used in order to signal client errors. They are encoded in the following way: -\r\n - ''' + """ # Syntax of error is - # So data is second element after error specifier - data = f'-{data}{PROTOCOL_SEPARATOR}' - - return data.encode('utf-8') + data = f"-{data}{PROTOCOL_SEPARATOR}" + + return data.encode("utf-8") def _build_protocol_2_simple_string(self, data): - ''' + """ Simple Strings are used to transmit non binary safe strings with minimal overhead. They are encoded in the following way: +\r\n - ''' + """ # Syntax of simple string is + # So data is second element after simple string specifier - data = f'+{data}{PROTOCOL_SEPARATOR}' - - return data.encode('utf-8') - + data = f"+{data}{PROTOCOL_SEPARATOR}" + + return data.encode("utf-8") + def _build_protocol_2_bulk_string(self, data): - ''' + """ Bulk Strings are used in order to represent a single binary safe string up to 512 MB in length. They are encoded in the following way: $\r\n For example, "foobar" is encoded as "$6\r\nfoobar\r\n". For nil values bulk strings are encoded with $-1\r\n - ''' - + """ + # If data is None then return nil value if data is None: - return f'${-1}{PROTOCOL_SEPARATOR}'.encode('utf-8') + return f"${-1}{PROTOCOL_SEPARATOR}".encode("utf-8") else: # Syntax of bulk string is $ # So data is second element after bulk string specifier - data = f'${len(data)}{PROTOCOL_SEPARATOR}{data}{PROTOCOL_SEPARATOR}' - - return data.encode('utf-8') + data = f"${len(data)}{PROTOCOL_SEPARATOR}{data}{PROTOCOL_SEPARATOR}" + + return data.encode("utf-8") diff --git a/redis_clone/server.py b/redis_clone/server.py index 0748e71..e360117 100644 --- a/redis_clone/server.py +++ b/redis_clone/server.py @@ -11,22 +11,24 @@ logger = logging.getLogger(__name__) -HOST = os.environ.get('REDIS_HOST', '0.0.0.0') -PORT = os.environ.get('REDIS_PORT', 9999) +HOST = os.environ.get("REDIS_HOST", "0.0.0.0") +PORT = os.environ.get("REDIS_PORT", 9999) + class Protocol_2_Commands(Enum): - ''' + """ Some common redis commands - ''' - SET = 'SET' - GET = 'GET' - DEL = 'DEL' - EXISTS = 'EXISTS' - INCR = 'INCR' - DECR = 'DECR' - PING = 'PING' - ECHO = 'ECHO' - + """ + + SET = "SET" + GET = "GET" + DEL = "DEL" + EXISTS = "EXISTS" + INCR = "INCR" + DECR = "DECR" + PING = "PING" + ECHO = "ECHO" + class RedisServer: def __init__(self, host, port) -> None: @@ -36,15 +38,17 @@ def __init__(self, host, port) -> None: self.response_builder = ResponseBuilder(protocol_version=2) self.data_store = {} self.running = False - + async def start(self): - logger.info('Starting server...') - self.server = await asyncio.start_server(self._handle_connection, self.host, self.port) + logger.info("Starting server...") + self.server = await asyncio.start_server( + self._handle_connection, self.host, self.port + ) async with self.server: await self.server.serve_forever() async def _handle_connection(self, reader, writer): - addr = writer.get_extra_info('peername') + addr = writer.get_extra_info("peername") logger.info(f"Connection established with {addr}") while True: @@ -54,12 +58,12 @@ async def _handle_connection(self, reader, writer): logger.info(f"Received data: {data}") # Convert bytes to string - data = data.decode('utf-8') + data = data.decode("utf-8") command_name, command_args = self.parser.parse_client_request(data) - logger.info(f'Command name: {command_name}') - logger.info(f'Command args: {command_args}') + logger.info(f"Command name: {command_name}") + logger.info(f"Command args: {command_args}") response = self._process_command(command_name, command_args) - logger.info(f'Response: {response}') + logger.info(f"Response: {response}") writer.write(response) await writer.drain() @@ -71,39 +75,61 @@ def _process_command(self, command_name, command_args) -> bytes: # Convert command name to uppercase command_name = command_name.upper() if command_name == Protocol_2_Commands.PING.value: - return self.response_builder.build_response(Protocol_2_Data_Types.SIMPLE_STRING, 'PONG') + return self.response_builder.build_response( + Protocol_2_Data_Types.SIMPLE_STRING, "PONG" + ) elif command_name == Protocol_2_Commands.ECHO.value: # Echo command returns the same string if len(command_args) == 0: - return self.response_builder.build_response(Protocol_2_Data_Types.ERROR, 'ERR wrong number of arguments for \'ECHO\' command') - return self.response_builder.build_response(Protocol_2_Data_Types.SIMPLE_STRING, " ".join(command_args)) + return self.response_builder.build_response( + Protocol_2_Data_Types.ERROR, + "ERR wrong number of arguments for 'ECHO' command", + ) + return self.response_builder.build_response( + Protocol_2_Data_Types.SIMPLE_STRING, " ".join(command_args) + ) elif command_name == Protocol_2_Commands.SET.value: # Minimum 2 arguments required key and value if len(command_args) < 2: - return self.response_builder.build_response(Protocol_2_Data_Types.ERROR, 'ERR wrong number of arguments for \'SET\' command') + return self.response_builder.build_response( + Protocol_2_Data_Types.ERROR, + "ERR wrong number of arguments for 'SET' command", + ) key = command_args[0] value = command_args[1] - + # Even if key exists, redis will overwrite the value self.data_store[key] = value - return self.response_builder.build_response(Protocol_2_Data_Types.SIMPLE_STRING, 'OK') + return self.response_builder.build_response( + Protocol_2_Data_Types.SIMPLE_STRING, "OK" + ) elif command_name == Protocol_2_Commands.GET.value: # Only 1 argument required key if len(command_args) != 1: - return self.response_builder.build_response(Protocol_2_Data_Types.ERROR, 'ERR wrong number of arguments for \'GET\' command') + return self.response_builder.build_response( + Protocol_2_Data_Types.ERROR, + "ERR wrong number of arguments for 'GET' command", + ) key = command_args[0] if key not in self.data_store: - return self.response_builder.build_response(Protocol_2_Data_Types.BULK_STRING) - - return self.response_builder.build_response(Protocol_2_Data_Types.BULK_STRING, self.data_store[key]) - - return self.response_builder.build_response(Protocol_2_Data_Types.ERROR, 'ERR unknown command \'{}\''.format(command_name)) - + return self.response_builder.build_response( + Protocol_2_Data_Types.BULK_STRING + ) + + return self.response_builder.build_response( + Protocol_2_Data_Types.BULK_STRING, self.data_store[key] + ) + + return self.response_builder.build_response( + Protocol_2_Data_Types.ERROR, "ERR unknown command '{}'".format(command_name) + ) + def stop(self): - logger.info('Stopping server...') + logger.info("Stopping server...") self.server.close() -if __name__ == '__main__': + +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) server = RedisServer(host=HOST, port=PORT) - asyncio.run(server.start()) \ No newline at end of file + asyncio.run(server.start()) diff --git a/setup.py b/setup.py index 7a3bfe1..ac475d9 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,16 @@ # Setup file for redis_clone -''' +""" Current Dev Dependencies: - pytest - redis -''' +""" import setuptools with open("README.md", "r") as fh: long_description = fh.read() - + setuptools.setup( name="redis_clone", version="0.0.1", @@ -23,14 +23,11 @@ packages=setuptools.find_packages(), classifiers=[ "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License" + "License :: OSI Approved :: MIT License", ], - python_requires='>=3.6', + python_requires=">=3.6", # Dev dependencies extras_require={ - 'dev': [ - 'pytest', - 'redis' - ], + "dev": ["pytest", "redis"], }, -) \ No newline at end of file +) diff --git a/tests/test_client_parser.py b/tests/test_client_parser.py index 59061ef..4aa1df6 100644 --- a/tests/test_client_parser.py +++ b/tests/test_client_parser.py @@ -1,42 +1,42 @@ # Using pytest for tests from redis_clone.redis_parser import Parser, Protocol_2_Data_Types + class TestParserClient: - def test_initial_command_request(self): - ''' + """ Test initial COMMAND request - ''' + """ parser = Parser(protocol_version=2) - + # Test initial connection - test_str = '*1\r\n$7\r\nCOMMAND\r\n' + test_str = "*1\r\n$7\r\nCOMMAND\r\n" command, args = self.parser.parse_client_request(test_str) - - assert command == 'COMMAND' + + assert command == "COMMAND" assert args == [] - + def test_set_command_request(self): - ''' + """ Test SET command request - ''' + """ # Test initial connection - test_str = '*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$7\r\nmyvalue\r\n' + test_str = "*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$7\r\nmyvalue\r\n" command, args = self.parser.parse_client_request(test_str) - assert command == 'SET' - assert args == ['mykey', 'myvalue'] - + assert command == "SET" + assert args == ["mykey", "myvalue"] + def test_get_command_request(self): - ''' + """ Test GET command request - ''' + """ # Test initial connection - test_str = '*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n' + test_str = "*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n" command, args = self.parser.parse_client_request(test_str) - assert command == 'GET' - assert args == ['mykey'] - + assert command == "GET" + assert args == ["mykey"] + def setup_method(self): - self.parser = Parser(protocol_version=2) \ No newline at end of file + self.parser = Parser(protocol_version=2) diff --git a/tests/test_server.py b/tests/test_server.py index 4eeee12..a93633a 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -5,6 +5,7 @@ REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") REDIS_PORT = os.environ.get("REDIS_PORT", 9999) + # Server should be running before running the tests @pytest.fixture(scope="function") def client(): @@ -12,15 +13,18 @@ def client(): yield r r.close() + def test_ping(client): response = client.ping() print(response) assert response == True + def test_echo(client): response = client.echo("Hello World") assert response == "Hello World" + def test_set_get(client): response = client.set("test_key", "test_value") assert response == True @@ -28,6 +32,7 @@ def test_set_get(client): value = client.get("test_key") assert value == "test_value" + def test_nonexistent_get(client): value = client.get("random") assert value is None