Skip to content

Commit

Permalink
Merge pull request #5 from generalpy101/add-expiration-support
Browse files Browse the repository at this point in the history
Added expiration support and other options support for SET and delete support
  • Loading branch information
generalpy101 authored Oct 15, 2023
2 parents 1d93493 + dfb9b7e commit 509d5be
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 26 deletions.
4 changes: 3 additions & 1 deletion redis_clone/redis_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def _parse_v2_client_request(self, data):
else:
raise Exception(f"Expected value for subargument {arg}, but none provided.")
else:
command_args.append((arg, None))
# Subargument does not take a value, so just append it to the command args.
# Adding True as a placeholder value to indicate that the subargument is present.
command_args.append((arg, True))

else:
command_args.append(arg)
Expand Down
30 changes: 21 additions & 9 deletions redis_clone/response_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,19 @@ def build_response(self, type, data=None):
else:
raise Exception("Protocol version not supported")

def _build_protocol_2_response(self, type, data):
def _build_protocol_2_response(self, data_type, data):
"""
Build response according to protocol version 2
"""
if type == Protocol_2_Data_Types.ERROR:
return self._build_protocol_2_error(data)
elif type == Protocol_2_Data_Types.SIMPLE_STRING:
return self._build_protocol_2_simple_string(data)
elif type == Protocol_2_Data_Types.BULK_STRING:
return self._build_protocol_2_bulk_string(data)
else:
raise Exception("Invalid protocol data type")
function_dict = {
Protocol_2_Data_Types.ERROR: self._build_protocol_2_error,
Protocol_2_Data_Types.SIMPLE_STRING: self._build_protocol_2_simple_string,
Protocol_2_Data_Types.BULK_STRING: self._build_protocol_2_bulk_string,
Protocol_2_Data_Types.INTEGER: self._build_protocol_2_integer,
}
if data_type not in function_dict:
raise Exception("Invalid response type")
return function_dict[data_type](data)

def _build_protocol_2_error(self, data):
"""
Expand Down Expand Up @@ -79,3 +80,14 @@ def _build_protocol_2_bulk_string(self, data):
data = b"$" + length + PROTOCOL_SEPARATOR + data.encode("utf-8") + PROTOCOL_SEPARATOR

return data

def _build_protocol_2_integer(self, data):
"""
Integers are used in order to represent whole numbers between -(2^63) and 2^63-1.
They are encoded in the following way:
:<data>\r\n
"""
# Syntax of integer is :<data>
# So data is second element after integer specifier
data = b":" + str(data).encode("utf-8") + PROTOCOL_SEPARATOR
return data
202 changes: 187 additions & 15 deletions redis_clone/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import socket
import time
import sys
import os
import asyncio
Expand Down Expand Up @@ -28,6 +28,43 @@ class Protocol_2_Commands(Enum):
DECR = "DECR"
PING = "PING"
ECHO = "ECHO"

class ExpiryValue:
def __init__(self, value, expiry_seconds=None, expiry_milliseconds=None, expiry_unix_timestamp_seconds=None, expiry_unix_timestamp_milliseconds=None) -> None:
self.value = value
self.expiry_seconds = time.time() + expiry_seconds if expiry_seconds else None
self.expiry_milliseconds = time.time() * 1000 + expiry_milliseconds if expiry_milliseconds else None
self.expiry_unix_timestamp_seconds = expiry_unix_timestamp_seconds
self.expiry_unix_timestamp_milliseconds = expiry_unix_timestamp_milliseconds

def get_value(self):
if self.expiry_milliseconds:
if self.expiry_milliseconds < int(time.time() * 1000):
return None
elif self.expiry_seconds:
if self.expiry_seconds < int(time.time()):
return None
elif self.expiry_unix_timestamp_milliseconds:
if self.expiry_unix_timestamp_milliseconds < int(time.time() * 1000):
return None
elif self.expiry_unix_timestamp_seconds:
if self.expiry_unix_timestamp_seconds < int(time.time()):
return None

return self.value

def get_expiry_seconds(self):
return self.expiry_seconds

def get_expiry_milliseconds(self):
return self.expiry_milliseconds

def get_expiry_unix_timestamp_seconds(self):
return self.expiry_unix_timestamp_seconds

def get_expiry_unix_timestamp_milliseconds(self):
return self.expiry_unix_timestamp_milliseconds



class RedisServer:
Expand Down Expand Up @@ -87,20 +124,8 @@ def _process_command(self, command_name, command_args) -> bytes:
Protocol_2_Data_Types.SIMPLE_STRING, " ".join(command_args)
)
elif command_name == Protocol_2_Commands.SET.value:
# Minimum 2 arguments required key and value
if len(command_args) < 2:
return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR,
"ERR wrong number of arguments for 'SET' command",
)
key = command_args[0]
value = command_args[1]
return self._handle_set_command(command_args)

# Even if key exists, redis will overwrite the value
self.data_store[key] = value
return self.response_builder.build_response(
Protocol_2_Data_Types.SIMPLE_STRING, "OK"
)
elif command_name == Protocol_2_Commands.GET.value:
# Only 1 argument required key
if len(command_args) != 1:
Expand All @@ -109,19 +134,166 @@ def _process_command(self, command_name, command_args) -> bytes:
"ERR wrong number of arguments for 'GET' command",
)
key = command_args[0]
value = None
if key not in self.data_store:
return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING
)
else:
# Check the key is of type ExpiryValue
# This is to ensure uniformity in the when setting and getting values
if isinstance(self.data_store[key], ExpiryValue):
value = self.data_store[key].get_value()

if value is None:
self._delete_expired_key(key)

return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING, self.data_store[key]
Protocol_2_Data_Types.BULK_STRING, value
)

elif command_name == Protocol_2_Commands.DEL.value:
# Minimum 1 argument required key
if len(command_args) < 1:
return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR,
"ERR wrong number of arguments for 'DEL' command",
)

keys_deleted = 0
for key in command_args:
if key in self.data_store:
del self.data_store[key]
keys_deleted += 1

return self.response_builder.build_response(
Protocol_2_Data_Types.INTEGER, keys_deleted
)


return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR, "ERR unknown command '{}'".format(command_name)
)

def _handle_set_command(self, command_args):
# Minimum 2 arguments required key and value
if len(command_args) < 2:
return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR,
"ERR wrong number of arguments for 'SET' command",
)
key = command_args[0]
value = command_args[1]

subarg_values = {
"EX": None, # seconds
"PX": None, # milliseconds
"EXAT": None, # unix timestamp in seconds
"PXAT": None, # unix timestamp in milliseconds
"KEEPTTL": None, # keep the ttl of the key boolean
"GET": None, # return the value of the key booelan
"NX": None, # set if key does not exist boolean
"XX": None, # set if key exists boolean
}

# Check set command has optional arguments
if len(command_args) > 2:
# Subargs are in format (arg, value)
for subarg in command_args[2:]:
subarg_values[subarg[0]] = subarg[1]

# Process subargs
# Check keepttl is not set with any other expiry subarg
if subarg_values["KEEPTTL"] and (subarg_values["EX"] or subarg_values["PX"] or subarg_values["EXAT"] or subarg_values["PXAT"]):
return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR,
"ERR invalid expire command syntax",
)

# Return error if both NX and XX are set
if subarg_values["NX"] and subarg_values["XX"]:
return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR,
"ERR XX and NX options at the same time are not compatible",
)

# Handle NX
# NX -- Only set the key if it does not already exist.
if subarg_values["NX"]:
if key in self.data_store:
return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING
)
else:
return self._assign_key_to_value(key, value, subarg_values)

# Handle XX
# XX -- Only set the key if it already exists.
if subarg_values["XX"]:
if key not in self.data_store:
return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING
)
else:
return self._assign_key_to_value(key, value, subarg_values)

# Handle GET
# GET -- Return the value of key
if subarg_values["GET"]:
if key in self.data_store:
return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING, self.data_store[key].get_value()
)
else:
return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING
)

# Handle KEEPTTL
# KEEPTTL -- Retain the time to live associated with the key.
if subarg_values["KEEPTTL"]:
if key in self.data_store:
self.data_store[key] = ExpiryValue(
value=value,
expiry_seconds=self.data_store[key].get_expiry_seconds(),
expiry_milliseconds=self.data_store[key].get_expiry_milliseconds(),
expiry_unix_timestamp_seconds=self.data_store[key].get_expiry_unix_timestamp_seconds(),
expiry_unix_timestamp_milliseconds=self.data_store[key].get_expiry_unix_timestamp_milliseconds(),
)

return self.response_builder.build_response(
Protocol_2_Data_Types.SIMPLE_STRING, "OK"
)
else:
return self.response_builder.build_response(
Protocol_2_Data_Types.BULK_STRING
)

# Normal case for set
return self._assign_key_to_value(key, value, subarg_values)

def _assign_key_to_value(self, key, value, subargs):
try:
self.data_store[key] = ExpiryValue(
value=value,
expiry_seconds=int(subargs["EX"]) if subargs["EX"] else None,
expiry_milliseconds=int(subargs["PX"]) if subargs["PX"] else None,
expiry_unix_timestamp_seconds=int(subargs["EXAT"]) if subargs["EXAT"] else None,
expiry_unix_timestamp_milliseconds=int(subargs["PXAT"]) if subargs["PXAT"] else None,
)
return self.response_builder.build_response(
Protocol_2_Data_Types.SIMPLE_STRING, "OK"
)
except ValueError:
return self.response_builder.build_response(
Protocol_2_Data_Types.ERROR,
"ERR value is not an integer or out of range",
)

def _delete_expired_key(self, key):
if key in self.data_store:
del self.data_store[key]

def stop(self):
logger.info("Stopping server...")
self.server.close()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_subargs_parsing(self):

print(args)
assert command == "SET"
assert args == ['mykey', 'myvalue', ('EX', '10'), ('NX', None)]
assert args == ['mykey', 'myvalue', ('EX', '10'), ('NX', True)]

def setup_method(self):
self.parser = Parser(protocol_version=2)

0 comments on commit 509d5be

Please sign in to comment.