Writing a Redis clone in Go from scratch
In this post we're going to write a basic Redis clone in Go that implements the most simple commands: GET,
+SET, DEL and QUIT. At the end you'll know how to parse a byte stream from a live TCP connection, and hopefully have a working
+implementation of Redis.
What's intersting about this project is that it's production ready (not really).
+ It's being used in production in an old Web app that I made for a client in 2017.
+ It has been running for a few months now without issues.
I mantain that app to this day and I charge like 50 bucks a month for it. I do it because
+ Im friends with the person that uses the app.
Long story short, the app's backend is written in PHP and uses Redis for caching, only GET, SET and DEL commands.
+ I asked my friend if I could replace it with my custom version and said yes, so I decided to give it a go.
If you're looking for C/C++ implementation, go check out this book.
What we'll be building
If you go to the
+command list on redis webpage you'll see that there are 463 commands to this day (maybe more if you're in the future).
+
That's a crazy number. Here, we're only implementing 4 commands: GET, SET, DEL, QUIT,
+ the other 459 commands are left as an exercise to the reader.
+
GET
GET key
+
Returns the value referenced by key.
+ If the key does not exist then nil is returned.
+
SET
SET command gains more features on newer versions of Redis. We're going to implement one that has all features
+ that were realeased up until version 6.0.0.
+
SET key value [NX | XX] [EX seconds | PX milliseconds]
+
Stores value as a string that is referenced by key.
+ Overwrites any data that was previously referenced by the key.
+
Options
- EX seconds -- Set the specified expire time, in seconds.
- PX milliseconds -- Set the specified expire time, in seconds.
- NX -- Only set the key if it does not already exist.
- XX -- Only set the key if it already exist.
DEL
DEL key [key ...]
+
Takes 'any' amount of keys as input and removes all of them from storage. If a key doesn't exist it is ignored.
+ Returns the amount of keys that were deleted.
+
QUIT
QUIT
+
When receiving this command the server closes the connection. It's useful for interactive sessions.
+ For production environments the client should close the connection without sending any commands.
+
Examples
Let's start an interactive session of redis to test some commands.
+ We can install redis-server with docker and run it locally.
+ Then we can use telnet to connect directly via TCP.
+ Open a terminal and execute the following instructions:
+
$ docker run -d --name redis-server -p 6379:6379 redis:alpine
+
+$ telnet 127.0.0.1 6379
+Trying 127.0.0.1...
+Connected to localhost.
+Escape character is '^]'.
+
At this point the prompt should be waiting for you to write something. We're gonna test a couple of commands.
+ In the code boxes below the first line is the command, following lines are the response.
+
GET a
+$-1
+
^ That weird $-1 is the special nil value. Which means there's nothing stored here.
+set a 1
++OK
+
^ First thing to notice here is that we can use lowercase version of SET.
+ Also, when the command is successful returns +OK.
+set b 2
++OK
+
SET c 3
++OK
+
^ Just storing a couple more values.
+GET a
+$1
+1
+
^ Here the response is returned in two lines. First line is the length of the string. Second line
+ is the actual string.
+get b
+$1
+2
+
^ We can also use lowercase version of GET, I bet commands are case-insensitive.
+get C
+$-1
+
^ Testing with uppercase C gives a nil. Keys seem to be case-sensitive, probably values too.
+ That makes sense.
+del a b c
+:3
+
^ Deleting everything returns the amount of keys deleted. Integers are indicated by ':'.
+quit
++OK
+Connection closed by foreign host.
+
^ When we send QUIT, the server closes the connection and we're back to our terminal session.
+With those tests we have enough information to start building. We learned a little bit about the
+ redis protocol and what the responses look like.
+
Sending commands
Until now we've been using the inline version of redis command.
+ There's another kind that follows the
+RESP (Redis serialization protocol).
The RESP protocol is quite similar to what we've seen in the examples above.
+The most important addition is arrays. Let's see a Client<>Server interaction
+ using arrays.
+
Client
*2
+$3
+GET
+$1
+a
+
Server$-1
+
The server response looks the same as in the inline version.
+ But what the client sends looks very different:
+- In this case, the first thing the client sends is '*' followed by the number of elements in the array,
+ so '*2' indicates that there are 2 elements in the array and they would be found in the following lines.
+
- After that we have '$3' which means we're expecting the first element to be a string of length 3.
+ Next line is the actual string, in our case is the command 'GET'.
+
- The next value is also a string and is the key passed to the command.
+
That's almost everything we need to start building a client. There's one last thing: error responses.
+
-Example error message
+-ERR unknown command 'foo'
+-WRONGTYPE Operation against a key holding the wrong kind of value
+
A response that starts with a '-' is considered an error. The first word is the error type.
+ We'll only gonna be using 'ERR' as a generic response.
+RESP protocol is what client libraries use to communicate with Redis.
+ With all that in our toolbox we're ready to start building.
+
Receiving connections
A crucial part of our serve is the ability to receive client's information.
+ The way that this is done is that the server listens on a TCP port and waits
+ for client connections. Let's start building the basic structure.
+
Create a new go module, open main.go and create a main function as follows.
+
package main
+
+import (
+ "bufio"
+ "fmt"
+ "log"
+ "net"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+var cache sync.Map
+
+func main() {
+ listener, err := net.Listen("tcp", ":6380")
+ if err != nil {
+ log.Fatal(err)
+ }
+ log.Println("Listening on tcp://0.0.0.0:6380")
+
+ for {
+ conn, err := listener.Accept()
+ log.Println("New connection", conn)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ go startSession(conn)
+ }
+}
+
After declaring the package and imports, we create a global sync.Map that would be our cache.
+ That's where keys are gonna be stored and retrieved.
+
On the main function we start listening on port 6380. After that we have an infinite loop that accepts
+ new connections and spawns a goroutine to handle the session.
+
Session handling
// startSession handles the client's session. Parses and executes commands and writes
+// responses back to the client.
+func startSession(conn net.Conn) {
+ defer func() {
+ log.Println("Closing connection", conn)
+ conn.Close()
+ }()
+ defer func() {
+ if err := recover(); err != nil {
+ log.Println("Recovering from error", err)
+ }
+ }()
+ p := NewParser(conn)
+ for {
+ cmd, err := p.command()
+ if err != nil {
+ log.Println("Error", err)
+ conn.Write([]uint8("-ERR " + err.Error() + "\r\n"))
+ break
+ }
+ if !cmd.handle() {
+ break
+ }
+ }
+}
+
It's super important that we close the connection when things are done. That's why we set a deferred function,
+ to close the connection when the session finishes.
+
After that we handle any panics using recover. We do this mainly because at some point we might be reading from
+ a connection that was closed by the client. And we don't want the entire server to die in case of an error.
+
Then we create a new parser and start trying to parse commands from the live connection. If we encounter an error
+ we write the error message back to the client and we finish the session.
+
When cmd.handle() returns false (signaling end of session) we break the loop and the session finishes.
+
Parsing commands
Basic parser structure:
+
// Parser contains the logic to read from a raw tcp connection and parse commands.
+type Parser struct {
+ conn net.Conn
+ r *bufio.Reader
+ // Used for inline parsing
+ line []byte
+ pos int
+}
+
+// NewParser returns a new Parser that reads from the given connection.
+func NewParser(conn net.Conn) *Parser {
+ return &Parser{
+ conn: conn,
+ r: bufio.NewReader(conn),
+ line: make([]byte, 0),
+ pos: 0,
+ }
+}
+
This is pretty straight-forward. We store a reference to the connection, a reader and then some
+ attributes that will help us with parsing.
+
The NewParser() function should be used as a contructor for Parser objects.
+
We need some helper functions that will make parsing easier:
+
func (p *Parser) current() byte {
+ if p.atEnd() {
+ return '\r'
+ }
+ return p.line[p.pos]
+}
+
+func (p *Parser) advance() {
+ p.pos++
+}
+
+func (p *Parser) atEnd() bool {
+ return p.pos >= len(p.line)
+}
+
+func (p *Parser) readLine() ([]byte, error) {
+ line, err := p.r.ReadBytes('\r')
+ if err != nil {
+ return nil, err
+ }
+ if _, err := p.r.ReadByte(); err != nil {
+ return nil, err
+ }
+ return line[:len(line)-1], nil
+}
+
Also quite simple.
+
- current(): Returns the character being pointed at by pos inside the line.
- advance(): Point to the next character in the line.
- atEnd(): Indicates if we're at the end of the line.
- readLine(): Reads the input from the connection up to the carriage return char. Skips the '\n' char.
Parsing strings
In Redis we can send commands like so:
+
SET text "quoted \"text\" here"
+
This means we need a way to handle \, " chars inside a string.
+
For that we need a special parsing function that will handle strings:
+
// consumeString reads a string argument from the current line.
+func (p *Parser) consumeString() (s []byte, err error) {
+ for p.current() != '"' && !p.atEnd() {
+ cur := p.current()
+ p.advance()
+ next := p.current()
+ if cur == '\\' && next == '"' {
+ s = append(s, '"')
+ p.advance()
+ } else {
+ s = append(s, cur)
+ }
+ }
+ if p.current() != '"' {
+ return nil, errors.New("unbalanced quotes in request")
+ }
+ p.advance()
+ return
+}
+
From the functions that we've declared up to this point it's pretty clear that our parser
+ will be reading the input line by line. And the consuming the line one char at a time.
+
The way consumeString() works is quite tricky.
+ It assumes that the initial " has been consumed before entering the function.
+ And it consumes all characters in the current line up until it reaches the closing quotes character
+ or the end of the line.
+
Inside the loop we can see that we're reading the current character and advancing the pointer, then
+ the next character.
+ When the user is sending an escaped quote inside the string we detect that by checking the current
+ and the next characters.
+ In this special case we end up advancing the pointer twice. Because we consumed two: chars the backslash
+ and the quote. But we added only one char to the output: ".
+
We append all other characters to the output buffer.
+
When the loop finishes, if we're not pointing to the end quote char, that means that the user
+ sent an invalid command and we return an error.
+
Otherwise we advance the pointer and return normally.
+
Parsing commands
// command parses and returns a Command.
+func (p *Parser) command() (Command, error) {
+ b, err := p.r.ReadByte()
+ if err != nil {
+ return Command{}, err
+ }
+ if b == '*' {
+ log.Println("resp array")
+ return p.respArray()
+ } else {
+ line, err := p.readLine()
+ if err != nil {
+ return Command{}, err
+ }
+ p.pos = 0
+ p.line = append([]byte{}, b)
+ p.line = append(p.line, line...)
+ return p.inline()
+ }
+}
+
We read the first character sent by the client. If it's an asterisk we handle it
+ using the RESP protocol. Otherwise we assume that it's an inline command.
+
Let's start by parsing the inline commands first.
+
// Command implements the behavior of the commands.
+type Command struct {
+ args []string
+ conn net.Conn
+}
+
+// inline parses an inline message and returns a Command. Returns an error when there's
+// a problem reading from the connection or parsing the command.
+func (p *Parser) inline() (Command, error) {
+ // skip initial whitespace if any
+ for p.current() == ' ' {
+ p.advance()
+ }
+ cmd := Command{conn: p.conn}
+ for !p.atEnd() {
+ arg, err := p.consumeArg()
+ if err != nil {
+ return cmd, err
+ }
+ if arg != "" {
+ cmd.args = append(cmd.args, arg)
+ }
+ }
+ return cmd, nil
+}
+
This is also quite easy to skim through. We skip any leading whitespace
+ in case the user sent something like ' GET a'.
+
We create a new Command object with a reference to the session connection.
+
While we're not at the end of the line we consume args and append them to the
+ arg list of the command object if they are not empty.
+
Consuming arguments
// consumeArg reads an argument from the current line.
+func (p *Parser) consumeArg() (s string, err error) {
+ for p.current() == ' ' {
+ p.advance()
+ }
+ if p.current() == '"' {
+ p.advance()
+ buf, err := p.consumeString()
+ return string(buf), err
+ }
+ for !p.atEnd() && p.current() != ' ' && p.current() != '\r' {
+ s += string(p.current())
+ p.advance()
+ }
+ return
+}
+
Same as before we consume any leading whitespace.
+
If we find a quoted string we call our function from before: consumeString().
+
We append all characters to the output until we reach a carriage return \r, a whitespace
+ or the end of the line.
+
Parsing RESP protocol
// respArray parses a RESP array and returns a Command. Returns an error when there's
+// a problem reading from the connection.
+func (p *Parser) respArray() (Command, error) {
+ cmd := Command{}
+ elementsStr, err := p.readLine()
+ if err != nil {
+ return cmd, err
+ }
+ elements, _ := strconv.Atoi(string(elementsStr))
+ log.Println("Elements", elements)
+ for i := 0; i < elements; i++ {
+ tp, err := p.r.ReadByte()
+ if err != nil {
+ return cmd, err
+ }
+ switch tp {
+ case ':':
+ arg, err := p.readLine()
+ if err != nil {
+ return cmd, err
+ }
+ cmd.args = append(cmd.args, string(arg))
+ case '$':
+ arg, err := p.readLine()
+ if err != nil {
+ return cmd, err
+ }
+ length, _ := strconv.Atoi(string(arg))
+ text := make([]byte, 0)
+ for i := 0; len(text) <= length; i++ {
+ line, err := p.readLine()
+ if err != nil {
+ return cmd, err
+ }
+ text = append(text, line...)
+ }
+ cmd.args = append(cmd.args, string(text[:length]))
+ case '*':
+ next, err := p.respArray()
+ if err != nil {
+ return cmd, err
+ }
+ cmd.args = append(cmd.args, next.args...)
+ }
+ }
+ return cmd, nil
+}
+
As we know, the leading asterisk has already been consumed from the connection input.
+ So, at this point, the first line contains the number of elements to be consumed.
+ We read that into an integer.
+
We create a for loop with that will parse all the elements in the array.
+ We consume the first character to detect which kind of element we need to consume: int, string or array.
+
The int case is quite simple, we just read until the rest of the line.
+
The array case is also quite simple, we call respArray() and append the args of the result,
+ to the current command object.
+
For strings we read the first line and get the size of the string.
+ We keep reading lines until we have read the indicated amount of characters.
+
Handling commands
This is the 'fun' part of the implementation. Were our server becomes alive.
+ In this section we'll implement the actual functionality of the commands.
+
Let's start with the cmd.handle() function that we saw in handleSession().
+
// handle Executes the command and writes the response. Returns false when the connection should be closed.
+func (cmd Command) handle() bool {
+ switch strings.ToUpper(cmd.args[0]) {
+ case "GET":
+ return cmd.get()
+ case "SET":
+ return cmd.set()
+ case "DEL":
+ return cmd.del()
+ case "QUIT":
+ return cmd.quit()
+ default:
+ log.Println("Command not supported", cmd.args[0])
+ cmd.conn.Write([]uint8("-ERR unknown command '" + cmd.args[0] + "'\r\n"))
+ }
+ return true
+}
+
Needs no further explanation. Let's implement the easiest command: QUIT.
+
// quit Used in interactive/inline mode, instructs the server to terminate the connection.
+func (cmd *Command) quit() bool {
+ if len(cmd.args) != 1 {
+ cmd.conn.Write([]uint8("-ERR wrong number of arguments for '" + cmd.args[0] + "' command\r\n"))
+ return true
+ }
+ log.Println("Handle QUIT")
+ cmd.conn.Write([]uint8("+OK\r\n"))
+ return false
+}
+
If any extra arguments were passed to QUIT, it returns an error.
+
Otherwise write +OK to the client and return false.
+ Which if you remember handleSession() is the value to indicate that the session has finished.
+ After that the connection will be automatically closed.
+
The next easieast command is DEL
+
// del Deletes a key from the cache.
+func (cmd *Command) del() bool {
+ count := 0
+ for _, k := range cmd.args[1:] {
+ if _, ok := cache.LoadAndDelete(k); ok {
+ count++
+ }
+ }
+ cmd.conn.Write([]uint8(fmt.Sprintf(":%d\r\n", count)))
+ return true
+}
+
Iterates through all the keys passed, deletes the ones that exists and
+ writes back to the client the amount of keys deleted.
+
Returns true, which means the connection is kept alive.
+
Handling GET
// get Fetches a key from the cache if exists.
+func (cmd Command) get() bool {
+ if len(cmd.args) != 2 {
+ cmd.conn.Write([]uint8("-ERR wrong number of arguments for '" + cmd.args[0] + "' command\r\n"))
+ return true
+ }
+ log.Println("Handle GET")
+ val, _ := cache.Load(cmd.args[1])
+ if val != nil {
+ res, _ := val.(string)
+ if strings.HasPrefix(res, "\"") {
+ res, _ = strconv.Unquote(res)
+ }
+ log.Println("Response length", len(res))
+ cmd.conn.Write([]uint8(fmt.Sprintf("$%d\r\n", len(res))))
+ cmd.conn.Write(append([]uint8(res), []uint8("\r\n")...))
+ } else {
+ cmd.conn.Write([]uint8("$-1\r\n"))
+ }
+ return true
+}
+
As before, we validate that the correct number of arguments were passed to the command.
+
We load the value from the global variable cache.
+
If the value is nil we write back to the client the special $-1.
+
When we have a value we cast it as string and unquote it in case it's quoted.
+ Then we write the length as the first line of the response and the string as the
+ second line of the response.
+
Handling SET
This is the most complicated command that we'll implement.
+
// set Stores a key and value on the cache. Optionally sets expiration on the key.
+func (cmd Command) set() bool {
+ if len(cmd.args) < 3 || len(cmd.args) > 6 {
+ cmd.conn.Write([]uint8("-ERR wrong number of arguments for '" + cmd.args[0] + "' command\r\n"))
+ return true
+ }
+ log.Println("Handle SET")
+ log.Println("Value length", len(cmd.args[2]))
+ if len(cmd.args) > 3 {
+ pos := 3
+ option := strings.ToUpper(cmd.args[pos])
+ switch option {
+ case "NX":
+ log.Println("Handle NX")
+ if _, ok := cache.Load(cmd.args[1]); ok {
+ cmd.conn.Write([]uint8("$-1\r\n"))
+ return true
+ }
+ pos++
+ case "XX":
+ log.Println("Handle XX")
+ if _, ok := cache.Load(cmd.args[1]); !ok {
+ cmd.conn.Write([]uint8("$-1\r\n"))
+ return true
+ }
+ pos++
+ }
+ if len(cmd.args) > pos {
+ if err := cmd.setExpiration(pos); err != nil {
+ cmd.conn.Write([]uint8("-ERR " + err.Error() + "\r\n"))
+ return true
+ }
+ }
+ }
+ cache.Store(cmd.args[1], cmd.args[2])
+ cmd.conn.Write([]uint8("+OK\r\n"))
+ return true
+}
+
As always, first thing we do is validate the number of arguments.
+ But in this case, SET is more tricky than the others.
+
When more than 3 arguments are passed we check for the NX or XX flags and handle them accordingly.
+
- NX -- Only set the key if it does not already exist.
- XX -- Only set the key if it already exist.
Then we parse the expiration flags if any. We'll see how that's done in a second.
+
After handling all those special cases we finally store the key and value in the cache,
+ write the +OK response and return true to keep the connection alive.
+
Expiration
// setExpiration Handles expiration when passed as part of the 'set' command.
+func (cmd Command) setExpiration(pos int) error {
+ option := strings.ToUpper(cmd.args[pos])
+ value, _ := strconv.Atoi(cmd.args[pos+1])
+ var duration time.Duration
+ switch option {
+ case "EX":
+ duration = time.Second * time.Duration(value)
+ case "PX":
+ duration = time.Millisecond * time.Duration(value)
+ default:
+ return fmt.Errorf("expiration option is not valid")
+ }
+ go func() {
+ log.Printf("Handling '%s', sleeping for %v\n", option, duration)
+ time.Sleep(duration)
+ cache.Delete(cmd.args[1])
+ }()
+ return nil
+}
+
We read the option and the expiration value, then we compute the duration
+ for each case and we spawn a new goroutine that sleeps for that amount of
+ time and the deletes the key from the cache.
+
This is not the most efficient way to do it, but it's simple and it works for us.
+
Working server
At this point we have an usable implementation of Redis.
+
Let's start the server the server and test it.
+
$ go run main.go
+2023/04/08 21:09:40 Listening on tcp://0.0.0.0:6380
+
On a different terminal connect to the server.
+
$ telnet 127.0.0.1 6380
+GET a
+$-1
+set a "test \"quotes\" are working"
++OK
+get a
+$25
+test "quotes" are working
+
It's alive!! Go have fun.
+
If you'd like to access the source code of this project there's a public gist
+ containing all of the code displayed here.
+
Link to source code