Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support redis transaction #10

Open
wants to merge 21 commits into
base: unstable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/kiwidb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ jobs:
- uses: actions/checkout@v4

- name: Install dependencies
run: |
brew update
brew install ccache

- name: Configure ccache
run: |
ccache --set-config=cache_dir=$HOME/.ccache
ccache --max-size=10G

- name: Restore ccache
uses: actions/cache@v3
with:
path: ~/.ccache
key: ${{ runner.os }}-ccache-${{ hashFiles('**/*.cpp','**/*.cc','**/*.c', '**/*.h') }}-clang
restore-keys: |
${{ runner.os }}-ccache-
ccache-

- name: Build
run: |
brew update
brew install ccache
Expand Down
Binary file added dump.rdb
Binary file not shown.
6 changes: 6 additions & 0 deletions src/base_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ void BaseCmd::Execute(PClient* client) {
if (!DoInitial(client)) {
return;
}

// Check whether the watch key is modified
if (HasFlag(kCmdFlagsWrite)) {
signalModifiedKey(client->Keys(), client->GetCurrentDB());
}

DoCmd(client);
}

Expand Down
2 changes: 1 addition & 1 deletion src/base_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ const std::string kCmdNameMSetnx = "msetnx";
const std::string kCmdNameMulti = "multi";
const std::string kCmdNameExec = "exec";
const std::string kCmdNameWatch = "watch";
const std::string kCmdNameUnwatch = "unwatch";
const std::string kCmdNameUnWatch = "unwatch";
const std::string kCmdNameDiscard = "discard";

// admin
Expand Down
94 changes: 55 additions & 39 deletions src/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "base_cmd.h"
#include "client.h"
#include "cmd_thread_pool_worker.h"
#include "config.h"
#include "env.h"
#include "kiwi.h"
Expand All @@ -23,6 +24,8 @@

namespace kiwi {

CmdTableManager cmd_table_manager;

Comment on lines +27 to +28
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure thread-safe access to the global command table manager.

The global cmd_table_manager variable could lead to race conditions in a multi-threaded environment. Consider:

  1. Making it thread-safe using a singleton pattern with proper synchronization
  2. Or moving it to a context object that manages its lifecycle and access
-CmdTableManager cmd_table_manager;
+class SafeCmdTableManager {
+private:
+    static std::mutex mutex_;
+    static std::unique_ptr<CmdTableManager> instance_;
+
+public:
+    static CmdTableManager& Instance() {
+        std::lock_guard<std::mutex> lock(mutex_);
+        if (!instance_) {
+            instance_ = std::make_unique<CmdTableManager>();
+        }
+        return *instance_;
+    }
+};
+
+std::mutex SafeCmdTableManager::mutex_;
+std::unique_ptr<CmdTableManager> SafeCmdTableManager::instance_;

Committable suggestion skipped: line range outside the PR's diff.

const ClientInfo ClientInfo::invalidClientInfo = {0, "", -1};

thread_local PClient* PClient::s_current = nullptr;
Expand Down Expand Up @@ -167,10 +170,6 @@ int PClient::HandlePacket(std::string&& data) {
}
}

for (const auto& item : params) {
FeedMonitors(item);
}

auto now = std::chrono::steady_clock::now();
time_stat_->SetEnqueueTs(now);

Expand All @@ -180,26 +179,9 @@ int PClient::HandlePacket(std::string&& data) {
g_kiwi->SubmitFast(std::make_shared<CmdThreadPoolTask>(shared_from_this(), std::move(params)));
}

// check transaction
// if (IsFlagOn(ClientFlag_multi)) {
// if (cmdName_ != kCmdNameMulti && cmdName_ != kCmdNameExec && cmdName_ != kCmdNameWatch &&
// cmdName_ != kCmdNameUnwatch && cmdName_ != kCmdNameDiscard) {
// if (!info->CheckParamsCount(static_cast<int>(params.size()))) {
// ERROR("queue failed: cmd {} has params {}", cmdName_, params.size());
// ReplyError(info ? PError_param : PError_unknowCmd, &reply_);
// FlagExecWrong();
// } else {
// if (!IsFlagOn(ClientFlag_wrongExec)) {
// queue_cmds_.push_back(params);
// }
//
// reply_.PushData("+QUEUED\r\n", 9);
// INFO("queue cmd {}", cmdName_);
// }
//
// return static_cast<int>(ptr - start);
// }
// }
// Propagate(params, GetCurrentDB());

// g_kiwi->SubmitFast(std::make_shared<CmdThreadPoolTask>(shared_from_this()));
Comment on lines +182 to +184
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Remove commented-out code.

If the propagation functionality is no longer needed, remove the commented code. If it's still required, implement it properly.

-  // Propagate(params, GetCurrentDB());
-
-  // g_kiwi->SubmitFast(std::make_shared<CmdThreadPoolTask>(shared_from_this()));


// check readonly slave and execute command
// PError err = PError_ok;
Expand Down Expand Up @@ -311,6 +293,22 @@ uint64_t PClient::GetUniqueID() const { return GetConnId(); }

ClientInfo PClient::GetClientInfo() const { return {GetUniqueID(), PeerIP().c_str(), PeerPort()}; }

bool PClient::CheckTransation(std::vector<std::string>& param) {
if (IsFlagOn(kClientFlagMulti)) {
if (cmdName_ != kCmdNameMulti && cmdName_ != kCmdNameExec && cmdName_ != kCmdNameWatch &&
cmdName_ != kCmdNameUnWatch && cmdName_ != kCmdNameDiscard) {
if (!IsFlagOn(kClientFlagWrongExec)) {
queue_cmds_.push_back(param);
}
INFO("queue cmd {}", cmdName_);
this->SetRes(CmdRes::kQueued);
g_kiwi->PushWriteTask(shared_from_this());
return true;
}
}
return false;
}
Comment on lines +296 to +310
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Fix typo in method name and add validation.

Issues found:

  1. Method name has a typo: "Transation" should be "Transaction"
  2. Missing validation for empty parameter vector

Apply this diff to fix the issues:

-bool PClient::CheckTransation(std::vector<std::string>& param) {
+bool PClient::CheckTransaction(std::vector<std::string>& param) {
+  if (param.empty()) {
+    return false;
+  }
  if (IsFlagOn(kClientFlagMulti)) {
    if (cmdName_ != kCmdNameMulti && cmdName_ != kCmdNameExec && cmdName_ != kCmdNameWatch &&
        cmdName_ != kCmdNameUnWatch && cmdName_ != kCmdNameDiscard) {
      if (!IsFlagOn(kClientFlagWrongExec)) {
        queue_cmds_.push_back(param);
      }
      INFO("queue cmd {}", cmdName_);
      this->SetRes(CmdRes::kQueued);
      g_kiwi->PushWriteTask(shared_from_this());
      return true;
    }
  }
  return false;
}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
bool PClient::CheckTransation(std::vector<std::string>& param) {
if (IsFlagOn(kClientFlagMulti)) {
if (cmdName_ != kCmdNameMulti && cmdName_ != kCmdNameExec && cmdName_ != kCmdNameWatch &&
cmdName_ != kCmdNameUnWatch && cmdName_ != kCmdNameDiscard) {
if (!IsFlagOn(kClientFlagWrongExec)) {
queue_cmds_.push_back(param);
}
INFO("queue cmd {}", cmdName_);
this->SetRes(CmdRes::kQueued);
g_kiwi->PushWriteTask(shared_from_this());
return true;
}
}
return false;
}
bool PClient::CheckTransaction(std::vector<std::string>& param) {
if (param.empty()) {
return false;
}
if (IsFlagOn(kClientFlagMulti)) {
if (cmdName_ != kCmdNameMulti && cmdName_ != kCmdNameExec && cmdName_ != kCmdNameWatch &&
cmdName_ != kCmdNameUnWatch && cmdName_ != kCmdNameDiscard) {
if (!IsFlagOn(kClientFlagWrongExec)) {
queue_cmds_.push_back(param);
}
INFO("queue cmd {}", cmdName_);
this->SetRes(CmdRes::kQueued);
g_kiwi->PushWriteTask(shared_from_this());
return true;
}
}
return false;
}


bool PClient::Watch(int dbno, const std::string& key) {
DEBUG("Client {} watch {}, db {}", name_, key, dbno);
return watch_keys_[dbno].insert(key).second;
Expand All @@ -321,7 +319,6 @@ bool PClient::NotifyDirty(int dbno, const std::string& key) {
INFO("client is already dirty {}", GetUniqueID());
return true;
}

if (watch_keys_[dbno].contains(key)) {
INFO("{} client become dirty because key {} in db {}", GetUniqueID(), key, dbno);
SetFlag(kClientFlagDirty);
Expand All @@ -338,29 +335,48 @@ bool PClient::Exec() {
this->ClearMulti();
this->ClearWatch();
};

DEBUG("Exec");
if (IsFlagOn(kClientFlagWrongExec)) {
return false;
}

if (IsFlagOn(kClientFlagDirty)) {
// FormatNullArray(&reply_);
AppendString("");
std::string message_ = "$-1\r\n";
resp_encode_->Reply(message_);
return true;
}
resp_encode_->ClearReply();
DEBUG("size : {}", queue_cmds_.size());
AppendArrayLen(queue_cmds_.size());
DEBUG("judge");
auto client = shared_from_this();
cmd_table_manager.InitCmdTable();
for (auto& cmd : queue_cmds_) {
SetCmdName(kstd::StringToLower(cmd[0]));
SetArgv(cmd);
kstd::StringToLower(client->cmdName_);
auto [cmdPtr, ret] = cmd_table_manager.GetCommand(client->CmdName(), client.get());

auto cmdstat_map = GetCommandStatMap();
CommandStatistics statistics;
if (cmdstat_map->find(cmd[0]) == cmdstat_map->end()) {
cmdstat_map->emplace(cmd[0], statistics);
}
auto now = std::chrono::steady_clock::now();
GetTimeStat()->SetDequeueTs(now);
cmdPtr->Execute(client.get());

// PreFormatMultiBulk(queue_cmds_.size(), &reply_);
// for (const auto& cmd : queue_cmds_) {
// DEBUG("EXEC {}, for client {}", cmd[0], UniqueId());
// const PCommandInfo* info = PCommandTable::GetCommandInfo(cmd[0]);
// PError err = PCommandTable::ExecuteCmd(cmd, info, &reply_);

// may dirty clients;
// if (err == PError_ok && (info->attr & PAttr_write)) {
// Propagate(cmd);
// }
// }
// Info Commandstats used
now = std::chrono::steady_clock::now();
GetTimeStat()->SetProcessDoneTs(now);
(*cmdstat_map)[cmd[0]].cmd_count_.fetch_add(1);
(*cmdstat_map)[cmd[0]].cmd_time_consuming_.fetch_add(GetTimeStat()->GetTotalTime());

FeedMonitors(cmd);
}
DEBUG("over");
g_kiwi->PushWriteTask(client);
// Propagate(client->params_, GetCurrentDB());
return true;
}

Expand Down
6 changes: 3 additions & 3 deletions src/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class PClient : public std::enable_shared_from_this<PClient> {
}
}

bool CheckTransation(std::vector<std::string>& param);
bool Watch(int dbno, const std::string& key);
bool NotifyDirty(int dbno, const std::string& key);
bool Exec();
Expand Down Expand Up @@ -239,8 +240,8 @@ class PClient : public std::enable_shared_from_this<PClient> {
std::unordered_set<std::string> pattern_channels_;

uint32_t flag_ = 0;
std::unordered_map<int32_t, std::unordered_set<std::string> > watch_keys_;
std::vector<std::vector<std::string> > queue_cmds_;
std::unordered_map<int32_t, std::unordered_set<std::string>> watch_keys_;
std::vector<std::vector<std::string>> queue_cmds_;

// blocked list
std::unordered_set<std::string> waiting_keys_;
Expand All @@ -262,7 +263,6 @@ class PClient : public std::enable_shared_from_this<PClient> {
time_t last_auth_ = 0;

ClientState state_;

uint64_t net_id_ = 0;
int8_t net_thread_index_ = 0;
net::SocketAddr addr_;
Expand Down
12 changes: 12 additions & 0 deletions src/cmd_table_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "cmd_set.h"
#include "cmd_zset.h"
#include "std_string.h"
#include "transaction.h"

namespace kiwi {

Expand Down Expand Up @@ -51,6 +52,10 @@ CmdTableManager::CmdTableManager() {
void CmdTableManager::InitCmdTable() {
std::unique_lock wl(mutex_);

if (cmds_->size() != 0) {
return;
}

// admin
ADD_COMMAND_GROUP(Config, -2);
ADD_SUBCOMMAND(Config, Get, -3);
Expand Down Expand Up @@ -198,6 +203,13 @@ void CmdTableManager::InitCmdTable() {
ADD_COMMAND(ZRevrank, 3);
ADD_COMMAND(ZRem, -3);
ADD_COMMAND(ZIncrby, 4);

// multi
ADD_COMMAND(Multi, 1);
ADD_COMMAND(Watch, -2);
ADD_COMMAND(UnWatch, 1);
ADD_COMMAND(Exec, 1);
ADD_COMMAND(Discard, 1);
}

std::pair<BaseCmd*, CmdRes> CmdTableManager::GetCommand(const std::string& cmdName, PClient* client) {
Expand Down
11 changes: 11 additions & 0 deletions src/cmd_thread_pool_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@ void CmdWorkThreadPoolWorker::Work() {
if (!cmdPtr) {
if (ret == CmdRes::kUnknownCmd) {
task->Client()->SetRes(CmdRes::kUnknownCmd, fmt::format("unknown command '{}'", param[0]));
task->Client()->FlagExecWrong();
WARN("client IP:{},port:{} unknown command '{}'", task->Client()->PeerIP(), task->Client()->PeerPort(),
param[0]);
} else if (ret == CmdRes::kUnknownSubCmd) {
task->Client()->SetRes(CmdRes::kUnknownSubCmd, task->Client()->argv_[1]);
task->Client()->FlagExecWrong();
WARN("client IP:{},port:{} unknown sub command '{}'", task->Client()->PeerIP(), task->Client()->PeerPort(),
task->Client()->argv_[1]);
} else {
task->Client()->SetRes(CmdRes::kWrongNum, param[0]);
task->Client()->FlagExecWrong();
WARN("client IP:{},port:{} unknown command '{}'", task->Client()->PeerIP(), task->Client()->PeerPort(),
param[0]);
}
Expand All @@ -52,10 +55,16 @@ void CmdWorkThreadPoolWorker::Work() {

if (!cmdPtr->CheckArg(task->Client()->ParamsSize())) {
task->Client()->SetRes(CmdRes::kWrongNum, param[0]);
task->Client()->FlagExecWrong();
g_kiwi->PushWriteTask(task->Client());
continue;
}

// check transaction
if (task->Client()->CheckTransation(param)) {
continue;
}

Comment on lines +63 to +67
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Transaction check implementation is correct but contains a typo

The transaction check implementation is properly integrated with appropriate error handling and state management. However, there is a consistent typo in the method name:

  • CheckTransation should be CheckTransaction (missing 'c')
🔗 Analysis chain

Verify transaction check implementation.

The new transaction check is added but its behavior needs verification, especially around error handling and state management.

Run this script to verify the transaction check implementation:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify transaction check implementation and error handling
# Test: Search for transaction check usage and related error handling
ast-grep --pattern 'CheckTransation($$$)'

Length of output: 43


Script:

#!/bin/bash
# Search for both spellings and Client class implementation
echo "=== Searching for Transaction methods ==="
rg -p "CheckTransaction|CheckTransation" -t cpp -t hpp

echo -e "\n=== Looking for Client class definition ==="
ast-grep --pattern 'class Client {
  $$$
}'

echo -e "\n=== Searching for transaction-related code ==="
rg -p "transaction|Transaction" -t cpp -t hpp -C 2

Length of output: 516


Script:

#!/bin/bash
# Search with corrected file types and broader patterns
echo "=== Searching for Client class and methods ==="
rg -p "class.*Client" -g "*.{h,cc}"

echo -e "\n=== Searching for transaction check methods ==="
rg -p "Check.*Trans.*" -g "*.{h,cc}" -C 2

echo -e "\n=== Searching for transaction-related code ==="
rg -p "transaction|Transaction" -g "*.{h,cc}" -C 2

echo -e "\n=== Looking for the specific file content ==="
cat src/cmd_thread_pool_worker.cc

Length of output: 12283

auto cmdstat_map = task->Client()->GetCommandStatMap();
CommandStatistics statistics;
if (cmdstat_map->find(param[0]) == cmdstat_map->end()) {
Expand All @@ -71,6 +80,8 @@ void CmdWorkThreadPoolWorker::Work() {
(*cmdstat_map)[param[0]].cmd_count_.fetch_add(1);
(*cmdstat_map)[param[0]].cmd_time_consuming_.fetch_add(task->Client()->GetTimeStat()->GetTotalTime());

task->Client()->FeedMonitors(param);

g_kiwi->PushWriteTask(task->Client());
}
}
Expand Down
18 changes: 18 additions & 0 deletions src/resp/resp2_encode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ void Resp2Encode::SetRes(CmdRes ret, const std::string& content) {
case CmdRes::kInvalidBitOffsetInt:
SetLineString("-ERR bit offset is not an integer or out of range");
break;
case CmdRes::kInvalidBitPosArgument:
SetLineString("-ERR The bit argument must be 1 or 0.");
break;
case CmdRes::kWrongBitOpNotNum:
SetLineString("-ERR BITOP NOT must be called with a single source key.");
break;
case CmdRes::kInvalidFloat:
SetLineString("-ERR value is not a valid float");
break;
Expand Down Expand Up @@ -88,6 +94,18 @@ void Resp2Encode::SetRes(CmdRes ret, const std::string& content) {
AppendStringRaw(
fmt::format("-WRONGTYPE Operation against a key holding the wrong kind of value {}\r\n", content));
break;
case CmdRes::kDirtyExec:
AppendStringRaw("-ERR EXECABORT Transaction discarded because of previous errors.");
AppendStringRaw(CRLF);
break;
case CmdRes::kPErrorWatch:
AppendStringRaw("-ERR WATCH inside MULTI is not allowed");
AppendStringRaw(CRLF);
break;
case CmdRes::kQueued:
AppendStringRaw("+QUEUED");
AppendStringRaw(CRLF);
break;
case CmdRes::kNoAuth:
SetLineString("-NOAUTH Authentication required");
break;
Expand Down
5 changes: 5 additions & 0 deletions src/resp/resp_encode.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ enum class CmdRes : std::int8_t {
kInvalidInt,
kInvalidBitInt,
kInvalidBitOffsetInt,
kInvalidBitPosArgument,
kWrongBitOpNotNum,
kInvalidFloat,
kOverFlow,
kNotFound,
Expand All @@ -38,6 +40,9 @@ enum class CmdRes : std::int8_t {
kInvalidCursor,
kWrongLeader,
kMultiKey,
kDirtyExec,
kPErrorWatch,
kQueued,
kNoAuth,
};

Expand Down
31 changes: 31 additions & 0 deletions src/store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "db.h"
#include "std/log.h"
#include "std/std_string.h"
#include "transaction.h"

namespace kiwi {

Expand Down Expand Up @@ -76,4 +77,34 @@ void Store::HandleTaskSpecificDB(const TasksVector& tasks) {
}
});
}

void Propagate(const std::vector<PString>& params, int dbno) {
assert(!params.empty());
//
// if (!g_dirtyKeys.empty()) {
// for (const auto& k : g_dirtyKeys) {
// PTransaction::Instance().NotifyDirty(PSTORE.GetDBNumber(), k);
//
// }
// g_dirtyKeys.clear();
// } else if (params.size() > 1) {
// PTransaction::Instance().NotifyDirty(PSTORE.GetDBNumber(), params[1]);
// }
if (params.size() > 1) {
PTransaction::Instance().NotifyDirty(dbno, params[1]);
}
}
Comment on lines +81 to +96
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Validate params size before accessing elements

In Propagate, you assume params.size() > 1 to access params[1]. While there's an assert, it's better to handle this explicitly to prevent potential out-of-bounds access in release builds where assertions may be disabled.

Modify the code to include a size check:

void Propagate(const std::vector<PString>& params, int dbno) {
  if (params.size() > 1) {
    PTransaction::Instance().NotifyDirty(dbno, params[1]);
  } else {
    // Handle insufficient parameters
  }
}


void Propagate(int dbno, const std::vector<PString>& params) {
PTransaction::Instance().NotifyDirtyAll(dbno);
Propagate(params, dbno);
}

void signalModifiedKey(const std::vector<PString>& keys, int dbno) {
if (keys.size() > 1) {
for (const auto& key : keys) {
PTransaction::Instance().NotifyDirty(dbno, key);
}
}
Comment on lines +103 to +108
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure all modified keys are signaled regardless of count

In signalModifiedKey, you're only notifying dirty keys when keys.size() > 1. This means that if there's only one key, it won't be notified, which may not be the intended behavior.

Update the condition to include all non-empty key vectors:

void signalModifiedKey(const std::vector<PString>& keys, int dbno) {
  if (!keys.empty()) {
    for (const auto& key : keys) {
      PTransaction::Instance().NotifyDirty(dbno, key);
    }
  }
}

}
} // namespace kiwi
Loading
Loading