From f3acf2dc0d8fb84bcf9a1acf30e917a379fc51d7 Mon Sep 17 00:00:00 2001 From: Kaijie Gu <2459548460@qq.com> Date: Sat, 16 Nov 2024 19:11:29 +0800 Subject: [PATCH] feat: add client command (#3) add client command --- src/base_cmd.h | 7 ++ src/client.cc | 14 ++-- src/client.h | 12 +++- src/client_map.cc | 103 +++++++++++++++++++++++++++++ src/client_map.h | 41 ++++++++++++ src/cmd_admin.cc | 136 +++++++++++++++++++++++++++++++++++++++ src/cmd_admin.h | 78 ++++++++++++++++++++++ src/cmd_table_manager.cc | 7 ++ src/kiwi.cc | 5 ++ src/proto_parser.cc | 3 +- tests/admin_test.go | 23 +++++++ 11 files changed, 421 insertions(+), 8 deletions(-) create mode 100644 src/client_map.cc create mode 100644 src/client_map.h mode change 100755 => 100644 src/proto_parser.cc diff --git a/src/base_cmd.h b/src/base_cmd.h index 5329583..06d0d24 100644 --- a/src/base_cmd.h +++ b/src/base_cmd.h @@ -78,6 +78,13 @@ const std::string kCmdNameUnwatch = "unwatch"; const std::string kCmdNameDiscard = "discard"; // admin +const std::string kCmdNameClient = "client"; +const std::string kSubCmdNameClientGetname = "getname"; +const std::string kSubCmdNameClientSetname = "setname"; +const std::string kSubCmdNameClientId = "id"; +const std::string kSubCmdNameClientList = "list"; +const std::string kSubCmdNameClientKill = "kill"; + const std::string kCmdNameConfig = "config"; const std::string kSubCmdNameConfigGet = "get"; const std::string kSubCmdNameConfigSet = "set"; diff --git a/src/client.cc b/src/client.cc index a2485da..0f88a80 100644 --- a/src/client.cc +++ b/src/client.cc @@ -27,6 +27,8 @@ namespace kiwi { +const ClientInfo ClientInfo::invalidClientInfo = {0, "", -1}; + void CmdRes::RedisAppendLen(std::string& str, int64_t ori, const std::string& prefix) { str.append(prefix); str.append(pstd::Int2string(ori)); @@ -459,7 +461,7 @@ void PClient::OnConnect() { std::string PClient::PeerIP() const { if (!addr_.IsValid()) { - ERROR("Invalid address detected for client {}", uniqueID()); + ERROR("Invalid address detected for client {}", GetUniqueID()); return ""; } return addr_.GetIP(); @@ -467,7 +469,7 @@ std::string PClient::PeerIP() const { int PClient::PeerPort() const { if (!addr_.IsValid()) { - ERROR("Invalid address detected for client {}", uniqueID()); + ERROR("Invalid address detected for client {}", GetUniqueID()); return 0; } return addr_.GetPort(); @@ -514,7 +516,9 @@ bool PClient::isClusterCmdTarget() const { return PRAFT.GetClusterCmdCtx().GetPeerIp() == PeerIP() && PRAFT.GetClusterCmdCtx().GetPort() == PeerPort(); } -uint64_t PClient::uniqueID() const { return GetConnId(); } +uint64_t PClient::GetUniqueID() const { return GetConnId(); } + +ClientInfo PClient::GetClientInfo() const { return {GetUniqueID(), PeerIP().c_str(), PeerPort()}; } bool PClient::Watch(int dbno, const std::string& key) { DEBUG("Client {} watch {}, db {}", name_, key, dbno); @@ -523,12 +527,12 @@ bool PClient::Watch(int dbno, const std::string& key) { bool PClient::NotifyDirty(int dbno, const std::string& key) { if (IsFlagOn(kClientFlagDirty)) { - INFO("client is already dirty {}", uniqueID()); + INFO("client is already dirty {}", GetUniqueID()); return true; } if (watch_keys_[dbno].contains(key)) { - INFO("{} client become dirty because key {} in db {}", uniqueID(), key, dbno); + INFO("{} client become dirty because key {} in db {}", GetUniqueID(), key, dbno); SetFlag(kClientFlagDirty); return true; } else { diff --git a/src/client.h b/src/client.h index 54ee98d..b596e63 100644 --- a/src/client.h +++ b/src/client.h @@ -157,6 +157,14 @@ enum class ClientState { class DB; struct PSlaveInfo; +struct ClientInfo { + uint64_t client_id; + std::string ip; + int port; + static const ClientInfo invalidClientInfo; + bool operator==(const ClientInfo& ci) const { return client_id == ci.client_id; } +}; + class PClient : public std::enable_shared_from_this, public CmdRes { public: // PClient() = delete; @@ -168,6 +176,8 @@ class PClient : public std::enable_shared_from_this, public CmdRes { std::string PeerIP() const; int PeerPort() const; + const int GetFd() const; + ClientInfo GetClientInfo() const; // bool SendPacket(const std::string& buf); // bool SendPacket(const void* data, size_t size); @@ -256,6 +266,7 @@ class PClient : public std::enable_shared_from_this, public CmdRes { void SetAuth() { auth_ = true; } bool GetAuth() const { return auth_; } + uint64_t GetUniqueID() const; void RewriteCmd(std::vector& params) { parser_.SetParams(params); } void Reexecutecommand() { this->executeCommand(); } @@ -287,7 +298,6 @@ class PClient : public std::enable_shared_from_this, public CmdRes { int processInlineCmd(const char*, size_t, std::vector&); void reset(); bool isPeerMaster() const; - uint64_t uniqueID() const; bool isClusterCmdTarget() const; diff --git a/src/client_map.cc b/src/client_map.cc new file mode 100644 index 0000000..9cc672b --- /dev/null +++ b/src/client_map.cc @@ -0,0 +1,103 @@ +#include "client_map.h" +#include "log.h" + +namespace kiwi { + +uint32_t ClientMap::GetAllClientInfos(std::vector& results) { + // client info string type: ip, port, fd. + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + if (auto client = client_weak.lock()) { + results.emplace_back(client->GetClientInfo()); + } + } + return results.size(); +} + +bool ClientMap::AddClient(int id, std::weak_ptr client) { + std::unique_lock client_map_lock(client_map_mutex_); + if (clients_.find(id) == clients_.end()) { + clients_.insert({id, client}); + return true; + } + return false; +} + +ClientInfo ClientMap::GetClientsInfoById(int id) { + std::shared_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(id); it != clients_.end()) { + if (auto client = it->second.lock(); client) { + return client->GetClientInfo(); + } + } + ERROR("Client with ID {} not found in GetClientsInfoById", id); + return ClientInfo::invalidClientInfo; +} + +bool ClientMap::RemoveClientById(int id) { + std::unique_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(id); it != clients_.end()) { + clients_.erase(it); + INFO("Removed client with ID {}", id); + return true; + } + return false; +} + +bool ClientMap::KillAllClients() { + std::vector> clients_to_close; + { + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + if (auto client = client_weak.lock()) { + clients_to_close.push_back(client); + } + } + } + for (auto& client : clients_to_close) { + client->Close(); + } + return true; +} + +bool ClientMap::KillClientByAddrPort(const std::string& addr_port) { + std::shared_ptr client_to_close; + { + std::shared_lock client_map_lock(client_map_mutex_); + for (auto& [id, client_weak] : clients_) { + if (auto client = client_weak.lock()) { + std::string client_ip_port = client->PeerIP() + ":" + std::to_string(client->PeerPort()); + if (client_ip_port == addr_port) { + client_to_close = client; + break; + } + } + } + } + if (client_to_close) { + client_to_close->Close(); + return true; + } + return false; +} + +bool ClientMap::KillClientById(int client_id) { + std::shared_ptr client_to_close; + { + std::shared_lock client_map_lock(client_map_mutex_); + if (auto it = clients_.find(client_id); it != clients_.end()) { + if (auto client = it->second.lock()) { + client_to_close = client; + } + } + } + if (client_to_close) { + INFO("Closing client with ID {}", client_id); + client_to_close->Close(); + INFO("Client with ID {} closed", client_id); + return true; + } + return false; +} + +} // namespace kiwi \ No newline at end of file diff --git a/src/client_map.h b/src/client_map.h new file mode 100644 index 0000000..ddbfa23 --- /dev/null +++ b/src/client_map.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include +#include "client.h" + +namespace kiwi { +class ClientMap { + private: + ClientMap() = default; + // 禁用复制构造函数和赋值运算符 + + private: + std::map> clients_; + std::shared_mutex client_map_mutex_; + + public: + static ClientMap& getInstance() { + static ClientMap instance; + return instance; + } + + ClientMap(const ClientMap&) = delete; + ClientMap& operator=(const ClientMap&) = delete; + + // client info function + kiwi::ClientInfo GetClientsInfoById(int id); + uint32_t GetAllClientInfos(std::vector& results); + + bool AddClient(int id, std::weak_ptr); + + bool RemoveClientById(int id); + + bool KillAllClients(); + bool KillClientById(int client_id); + bool KillClientByAddrPort(const std::string& addr_port); +}; + +} // namespace kiwi \ No newline at end of file diff --git a/src/cmd_admin.cc b/src/cmd_admin.cc index 3e214d6..53d38e9 100644 --- a/src/cmd_admin.cc +++ b/src/cmd_admin.cc @@ -35,6 +35,7 @@ #include "praft/praft.h" #include "pstd/env.h" +#include "client_map.h" #include "cmd_table_manager.h" #include "slow_log.h" #include "store.h" @@ -644,6 +645,141 @@ void SortCmd::InitialArgument() { get_patterns_.clear(); ret_.clear(); } +CmdClient::CmdClient(const std::string& name, int arity) + : BaseCmdGroup(name, kCmdFlagsReadonly | kCmdFlagsAdmin, kAclCategoryAdmin) {} + +bool CmdClient::HasSubCommand() const { return true; } + +CmdClientGetname::CmdClientGetname(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsReadonly, kAclCategoryAdmin) {} + +bool CmdClientGetname::DoInitial(PClient* client) { return true; } + +void CmdClientGetname::DoCmd(PClient* client) { client->AppendString(client->GetName()); } + +CmdClientSetname::CmdClientSetname(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsWrite, kAclCategoryAdmin) {} + +bool CmdClientSetname::DoInitial(PClient* client) { return true; } + +void kiwi::CmdClientSetname::DoCmd(PClient* client) { + client->SetName(client->argv_[2]); + client->SetRes(CmdRes::kOK); +} + +CmdClientId::CmdClientId(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsReadonly, kAclCategoryAdmin) {} + +bool CmdClientId::DoInitial(PClient* client) { return true; } + +void CmdClientId::DoCmd(PClient* client) { client->AppendInteger(client->GetUniqueID()); } + +CmdClientKill::CmdClientKill(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin, kAclCategoryAdmin) {} + +bool CmdClientKill::DoInitial(PClient* client) { + if (client->argv_.size() == 3 && strcasecmp(client->argv_[2].data(), "all") == 0) { + kill_type_ = Type::ALL; + return true; + } else if (client->argv_.size() == 4 && strcasecmp(client->argv_[2].data(), "addr") == 0) { + kill_type_ = Type::ADDR; + return true; + } else if (client->argv_.size() == 4 && strcasecmp(client->argv_[2].data(), "id") == 0) { + kill_type_ = Type::ID; + return true; + } else { + client->SetRes(CmdRes::kWrongNum, client->CmdName()); + return false; + } +} + +void CmdClientKill::DoCmd(PClient* client) { + bool ret; + auto& client_map = kiwi::ClientMap::getInstance(); + switch (kill_type_) { + case Type::ALL: { + ret = client_map.KillAllClients(); + break; + } + case Type::ADDR: { + ret = client_map.KillClientByAddrPort(client->argv_[3]); + break; + } + case Type::ID: { + try { + int client_id = stoi(client->argv_[3]); + ret = client_map.KillClientById(client_id); + } catch (const std::exception& e) { + client->SetRes(CmdRes::kErrOther, "Invalid client id"); + return; + } + } + default: + break; + } + ret == true ? client->SetRes(CmdRes::kOK) : client->SetRes(CmdRes::kErrOther, "No such client"); +} + +CmdClientList::CmdClientList(const std::string& name, int16_t arity) + : BaseCmd(name, arity, kCmdFlagsAdmin | kCmdFlagsReadonly, kAclCategoryAdmin) {} + +bool CmdClientList::DoInitial(PClient* client) { + if (client->argv_.size() == 2) { + list_type_ = Type::DEFAULT; + return true; + } + if (client->argv_.size() > 3 && strcasecmp(client->argv_[2].data(), "id") == 0) { + list_type_ = Type::ID; + return true; + } + client->SetRes(CmdRes::kErrOther, "Syntax error, try CLIENT (LIST [ID client_id_1, client_id_2...])"); + return false; +} + +void CmdClientList::DoCmd(PClient* client) { + auto& client_map = ClientMap::getInstance(); + switch (list_type_) { + case Type::DEFAULT: { + std::vector client_infos; + client_map.GetAllClientInfos(client_infos); + client->AppendArrayLen(client_infos.size()); + if (client_infos.size() == 0) { + return; + } + char buf[128]; + for (auto& client_info : client_infos) { + // client-> + snprintf(buf, sizeof(buf), "ID=%ld IP=%s PORT=%d\n", client_info.client_id, client_info.ip.c_str(), + client_info.port); + client->AppendString(std::string(buf)); + } + break; + } + case Type::ID: { + client->AppendArrayLen(client->argv_.size() - 3); + + for (size_t i = 3; i < client->argv_.size(); i++) { + try { + int client_id = std::stoi(client->argv_[i]); + auto client_info = client_map.GetClientsInfoById(client_id); + if (client_info == ClientInfo::invalidClientInfo) { + client->SetRes(CmdRes::kErrOther, "Invalid client id"); + return; + } + std::string result = + fmt::format("ID={} IP={} PORT={}\n", client_info.client_id, client_info.ip, client_info.port); + client->AppendString(result); + } catch (const std::exception& e) { + client->SetRes(CmdRes::kErrOther, "Invalid client id"); + return; + } + } + break; + } + default: + break; + } +} MonitorCmd::MonitorCmd(const std::string& name, int arity) : BaseCmd(name, arity, kCmdFlagsReadonly | kCmdFlagsAdmin, kAclCategoryAdmin) {} diff --git a/src/cmd_admin.h b/src/cmd_admin.h index 37f23c1..5097182 100644 --- a/src/cmd_admin.h +++ b/src/cmd_admin.h @@ -84,6 +84,84 @@ class FlushallCmd : public BaseCmd { void DoCmd(PClient* client) override; }; +class CmdClient : public BaseCmdGroup { + public: + CmdClient(const std::string& name, int arity); + bool HasSubCommand() const override; + + protected: + std::string operation_, info_; + bool DoInitial(PClient* client) override { return true; } + + private: + const static std::string CLIENT_LIST_S; + const static std::string CLIENT_KILL_S; + + void DoCmd(PClient* client) override {} +}; + +class CmdClientGetname : public BaseCmd { + public: + CmdClientGetname(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientSetname : public BaseCmd { + public: + CmdClientSetname(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientId : public BaseCmd { + public: + CmdClientId(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientList : public BaseCmd { + private: + enum class Type { DEFAULT, IDLE, ADDR, ID } list_type_; + std::string info_; + + public: + CmdClientList(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + +class CmdClientKill : public BaseCmd { + private: + enum class Type { ALL, ADDR, ID } kill_type_; + + public: + CmdClientKill(const std::string& name, int16_t arity); + + protected: + bool DoInitial(PClient* client) override; + + private: + void DoCmd(PClient* client) override; +}; + class SelectCmd : public BaseCmd { public: SelectCmd(const std::string& name, int16_t arity); diff --git a/src/cmd_table_manager.cc b/src/cmd_table_manager.cc index 589ccd9..d8b2543 100644 --- a/src/cmd_table_manager.cc +++ b/src/cmd_table_manager.cc @@ -63,6 +63,13 @@ void CmdTableManager::InitCmdTable() { ADD_COMMAND(Sort, -2); ADD_COMMAND(Monitor, 1); + ADD_COMMAND_GROUP(Client, -2); + ADD_SUBCOMMAND(Client, Getname, 2); + ADD_SUBCOMMAND(Client, Setname, 3); + ADD_SUBCOMMAND(Client, Id, 2); + ADD_SUBCOMMAND(Client, List, -2); + ADD_SUBCOMMAND(Client, Kill, -3); + // server ADD_COMMAND(Flushdb, 1); ADD_COMMAND(Flushall, 1); diff --git a/src/kiwi.cc b/src/kiwi.cc index 1fc5455..46a90f1 100644 --- a/src/kiwi.cc +++ b/src/kiwi.cc @@ -22,6 +22,7 @@ #include #include "client.h" +#include "client_map.h" #include "config.h" #include "helper.h" #include "kiwi.h" @@ -155,6 +156,8 @@ void KiwiDB::OnNewConnection(uint64_t connId, std::shared_ptr& cl INFO("New connection from {}:{}", addr.GetIP(), addr.GetPort()); client->SetSocketAddr(addr); client->OnConnect(); + // add new PClient to clients + ClientMap::getInstance().AddClient(client->GetUniqueID(), client); } bool KiwiDB::Init() { @@ -206,6 +209,7 @@ bool KiwiDB::Init() { event_server_->SetOnCreate([](uint64_t connID, std::shared_ptr& client, const net::SocketAddr& addr) { client->SetSocketAddr(addr); client->OnConnect(); + ClientMap::getInstance().AddClient(client->GetUniqueID(), client); INFO("New connection from fd:{} IP:{} port:{}", connID, addr.GetIP(), addr.GetPort()); }); @@ -216,6 +220,7 @@ bool KiwiDB::Init() { event_server_->SetOnClose([](std::shared_ptr& client, std::string&& msg) { INFO("Close connection id:{} msg:{}", client->GetConnId(), msg); client->OnClose(); + ClientMap::getInstance().RemoveClientById(client->GetUniqueID()); }); event_server_->InitTimer(10); diff --git a/src/proto_parser.cc b/src/proto_parser.cc old mode 100755 new mode 100644 index c0e7844..b13d6e3 --- a/src/proto_parser.cc +++ b/src/proto_parser.cc @@ -26,7 +26,6 @@ void PProtoParser::Reset() { numOfParam_ = 0; params_.clear(); - } PParseResult PProtoParser::ParseRequest(const char*& ptr, const char* end) { @@ -96,7 +95,7 @@ PParseResult PProtoParser::parseStrval(const char*& ptr, const char* end, PStrin assert(paramLen_ >= 0); if (static_cast(end - ptr) < paramLen_ + 2) { - paramLen_-=(end-ptr); + paramLen_ -= (end - ptr); result.append(ptr, end - ptr); return PParseResult::kWait; } diff --git a/tests/admin_test.go b/tests/admin_test.go index 1bdc6d1..154cc38 100644 --- a/tests/admin_test.go +++ b/tests/admin_test.go @@ -252,6 +252,29 @@ var _ = Describe("Admin", Ordered, func() { Expect(del2.Err()).NotTo(HaveOccurred()) }) + It("Cmd Client", func() { + conn := client.Conn() + set := conn.ClientSetName(ctx, "clientxxx") + Expect(set.Err()).NotTo(HaveOccurred()) + Expect(set.Val()).To(Equal(true)) + + get := conn.ClientGetName(ctx) + Expect(get.Err()).NotTo(HaveOccurred()) + Expect(get.Val()).To(Equal("clientxxx")) + + resId := conn.ClientID(ctx).Err() + Expect(resId).NotTo(HaveOccurred()) + Expect(client.ClientID(ctx).Val()).To(BeNumerically(">=", 0)) + + resKillFilter := conn.ClientKillByFilter(ctx, "ADDR", "1.1.1.1:1111") + Expect(resKillFilter.Err()).To(MatchError("ERR No such client")) + Expect(resKillFilter.Val()).To(Equal(int64(0))) + + resKillFilter = conn.ClientKillByFilter(ctx, "ID", "1") + Expect(resKillFilter.Err()).To(MatchError("ERR No such client")) + Expect(resKillFilter.Val()).To(Equal(int64(0))) + }) + // It("should monitor", Label("monitor"), func() { // ress := make(chan string) // client1 := s.NewClient()