From 4e224665454104f053ad1ee79504240b5b6d351d Mon Sep 17 00:00:00 2001 From: Taras Pashkin <32797206+tpashkin@users.noreply.github.com> Date: Fri, 6 Sep 2024 20:03:14 +0300 Subject: [PATCH] nbd netlink refactoring (#1963) nbd netlink refactoring - fix context leak when nl_socket_modify_cb fails - use clojure instead of a static method + custom context as a socket callback - move TNetlinkDevice methods outside of the class definition to reduce nesting - use nl_send_auto + nl_wait_for_ack instead of nl_send_sync to differentiate between send and recv errors - remove "Device" suffixes from TNetlinkDevice methods --- cloud/blockstore/libs/nbd/netlink_device.cpp | 321 ++++++++++--------- 1 file changed, 172 insertions(+), 149 deletions(-) diff --git a/cloud/blockstore/libs/nbd/netlink_device.cpp b/cloud/blockstore/libs/nbd/netlink_device.cpp index 07176cca60b..0b491d55021 100644 --- a/cloud/blockstore/libs/nbd/netlink_device.cpp +++ b/cloud/blockstore/libs/nbd/netlink_device.cpp @@ -19,6 +19,8 @@ namespace { using namespace NThreading; +using TResponseHandler = std::function; + //////////////////////////////////////////////////////////////////////////////// constexpr TStringBuf NBD_DEVICE_SUFFIX = "/dev/nbd"; @@ -71,6 +73,31 @@ class TNetlinkSocket { return Family; } + + template + void SetCallback(nl_cb_type type, F func) + { + auto arg = std::make_unique(std::move(func)); + if (int err = nl_socket_modify_cb( + Socket, + type, + NL_CB_CUSTOM, + TNetlinkSocket::ResponseHandler, + arg.get())) + { + throw TServiceError(E_FAIL) + << "unable to set socket callback: " << nl_geterror(err); + } + arg.release(); + } + + static int ResponseHandler(nl_msg* msg, void* arg) + { + auto func = std::unique_ptr( + static_cast(arg)); + + return (*func)(static_cast(nlmsg_data(nlmsg_hdr(msg)))); + } }; //////////////////////////////////////////////////////////////////////////////// @@ -126,9 +153,7 @@ class TNetlinkMessage ~TNetlinkMessage() { - if (Message) { - nlmsg_free(Message); - } + nlmsg_free(Message); } template @@ -147,12 +172,15 @@ class TNetlinkMessage void Send(nl_sock* socket) { - // send will free message even if it fails - auto* message = Message; - Message = nullptr; - if (int err = nl_send_sync(socket, message)) { + if (int err = nl_send_auto(socket, Message); err < 0) { + throw TServiceError(E_FAIL) + << "send error: " << nl_geterror(err); + } + if (int err = nl_wait_for_ack(socket)) { + // this is either recv error, or an actual error message received + // from the kernel throw TServiceError(E_FAIL) - << "unable to send message: " << nl_geterror(err); + << "recv error: " << nl_geterror(err); } } }; @@ -163,16 +191,6 @@ class TNetlinkDevice final : public IDevice , public std::enable_shared_from_this { -private: - struct THandlerContext - { - std::shared_ptr Device; - - THandlerContext(std::shared_ptr device) - : Device(std::move(device)) - {} - }; - private: const ILoggingServicePtr Logging; const TNetworkAddress ConnectAddress; @@ -192,109 +210,126 @@ class TNetlinkDevice final public: TNetlinkDevice( - ILoggingServicePtr logging, - TNetworkAddress connectAddress, - TString deviceName, - TDuration timeout, - TDuration deadConnectionTimeout, - bool reconfigure) - : Logging(std::move(logging)) - , ConnectAddress(std::move(connectAddress)) - , DeviceName(std::move(deviceName)) - , Timeout(timeout) - , DeadConnectionTimeout(deadConnectionTimeout) - , Reconfigure(reconfigure) - { - Log = Logging->CreateLog("BLOCKSTORE_NBD"); - } + ILoggingServicePtr logging, + TNetworkAddress connectAddress, + TString deviceName, + TDuration timeout, + TDuration deadConnectionTimeout, + bool reconfigure); - ~TNetlinkDevice() - { - Stop(false).GetValueSync(); - } + ~TNetlinkDevice(); - TFuture Start() override - { - try { - ParseIndex(); - ConnectSocket(); - ConnectDevice(); - } catch (const TServiceError& e) { - StartResult.SetValue(MakeError( - e.GetCode(), - TStringBuilder() << "unable to configure " << DeviceName - << ": " << e.what())); - } + TFuture Start() override; + TFuture Stop(bool deleteDevice) override; + TFuture Resize(ui64 deviceSizeInBytes) override; - // will be set asynchronously in Connect > HandleStatus > DoConnect - return StartResult.GetFuture(); - } +private: + void ParseIndex(); - TFuture Stop(bool deleteDevice) override - { - if (AtomicSwap(&ShouldStop, 1) == 1) { - return StopResult.GetFuture(); - } + void ConnectSocket(); + void DisconnectSocket(); - if (!deleteDevice) { - StopResult.SetValue(MakeError(S_OK)); - return StopResult.GetFuture(); - } + void Connect(); + void Disconnect(); + void DoConnect(bool connected); - try { - DisconnectDevice(); - DisconnectSocket(); - StopResult.SetValue(MakeError(S_OK)); - } catch (const TServiceError& e) { - StopResult.SetValue(MakeError( - e.GetCode(), - TStringBuilder() << "unable to disconnect " << DeviceName - << ": " << e.what())); - } + int StatusHandler(genlmsghdr* header); +}; - return StopResult.GetFuture(); +//////////////////////////////////////////////////////////////////////////////// + +TNetlinkDevice::TNetlinkDevice( + ILoggingServicePtr logging, + TNetworkAddress connectAddress, + TString deviceName, + TDuration timeout, + TDuration deadConnectionTimeout, + bool reconfigure) + : Logging(std::move(logging)) + , ConnectAddress(std::move(connectAddress)) + , DeviceName(std::move(deviceName)) + , Timeout(timeout) + , DeadConnectionTimeout(deadConnectionTimeout) + , Reconfigure(reconfigure) +{ + Log = Logging->CreateLog("BLOCKSTORE_NBD"); +} + +TNetlinkDevice::~TNetlinkDevice() +{ + Stop(false).GetValueSync(); +} + +TFuture TNetlinkDevice::Start() +{ + try { + ParseIndex(); + ConnectSocket(); + Connect(); + + } catch (const TServiceError& e) { + StartResult.SetValue(MakeError( + e.GetCode(), + TStringBuilder() + << "unable to configure " << DeviceName << ": " << e.what())); } - NThreading::TFuture Resize(ui64 deviceSizeInBytes) override - { - try { - TNetlinkSocket socket; - TNetlinkMessage message(socket.GetFamily(), NBD_CMD_RECONFIGURE); + // will be set asynchronously in Connect > HandleStatus > DoConnect + return StartResult.GetFuture(); +} - message.Put(NBD_ATTR_INDEX, DeviceIndex); - message.Put(NBD_ATTR_SIZE_BYTES, deviceSizeInBytes); +TFuture TNetlinkDevice::Stop(bool deleteDevice) +{ + if (AtomicSwap(&ShouldStop, 1) == 1) { + return StopResult.GetFuture(); + } - { - auto attr = message.Nest(NBD_ATTR_SOCKETS); - auto item = message.Nest(NBD_SOCK_ITEM); - message.Put(NBD_SOCK_FD, static_cast(Socket)); - } + if (!deleteDevice) { + StopResult.SetValue(MakeError(S_OK)); + return StopResult.GetFuture(); + } - message.Send(socket); - } catch (const TServiceError& e) { - return NThreading::MakeFuture(MakeError( - e.GetCode(), - TStringBuilder() - << "unable to resize " << DeviceName << ": " << e.what())); - } + try { + Disconnect(); + DisconnectSocket(); + StopResult.SetValue(MakeError(S_OK)); - return NThreading::MakeFuture(MakeError(S_OK)); + } catch (const TServiceError& e) { + StopResult.SetValue(MakeError( + e.GetCode(), + TStringBuilder() + << "unable to disconnect " << DeviceName << ": " << e.what())); } -private: - void ParseIndex(); + return StopResult.GetFuture(); +} - void ConnectSocket(); - void DisconnectSocket(); +TFuture TNetlinkDevice::Resize(ui64 deviceSizeInBytes) +{ + try { + TNetlinkSocket socket; + TNetlinkMessage message(socket.GetFamily(), NBD_CMD_RECONFIGURE); - void ConnectDevice(); - void DoConnectDevice(bool connected); - void DisconnectDevice(); + message.Put(NBD_ATTR_INDEX, DeviceIndex); + message.Put(NBD_ATTR_SIZE_BYTES, deviceSizeInBytes); - static int StatusHandler(nl_msg* message, void* argument); -}; + { + auto attr = message.Nest(NBD_ATTR_SOCKETS); + auto item = message.Nest(NBD_SOCK_ITEM); + message.Put(NBD_SOCK_FD, static_cast(Socket)); + } -//////////////////////////////////////////////////////////////////////////////// + message.Send(socket); + + } catch (const TServiceError& e) { + return MakeFuture(MakeError( + e.GetCode(), + TStringBuilder() + << "unable to resize " << DeviceName << ": " << e.what())); + } + + return MakeFuture(MakeError(S_OK)); +} void TNetlinkDevice::ParseIndex() { @@ -331,7 +366,33 @@ void TNetlinkDevice::DisconnectSocket() Socket.Close(); } -void TNetlinkDevice::DoConnectDevice(bool connected) +// queries device status eand registers callback that will connect +// or reconfigure (if Reconfigure == true) specified device +void TNetlinkDevice::Connect() +{ + TNetlinkSocket socket; + socket.SetCallback( + NL_CB_VALID, + [device = shared_from_this()] (auto* header) { + return device->StatusHandler(header); + }); + + TNetlinkMessage message(socket.GetFamily(), NBD_CMD_STATUS); + message.Put(NBD_ATTR_INDEX, DeviceIndex); + message.Send(socket); +} + +void TNetlinkDevice::Disconnect() +{ + STORAGE_INFO("disconnect " << DeviceName); + + TNetlinkSocket socket; + TNetlinkMessage message(socket.GetFamily(), NBD_CMD_DISCONNECT); + message.Put(NBD_ATTR_INDEX, DeviceIndex); + message.Send(socket); +} + +void TNetlinkDevice::DoConnect(bool connected) { try { auto command = NBD_CMD_CONNECT; @@ -383,45 +444,8 @@ void TNetlinkDevice::DoConnectDevice(bool connected) } } -void TNetlinkDevice::DisconnectDevice() -{ - STORAGE_INFO("disconnect " << DeviceName); - - TNetlinkSocket socket; - TNetlinkMessage message(socket.GetFamily(), NBD_CMD_DISCONNECT); - message.Put(NBD_ATTR_INDEX, DeviceIndex); - message.Send(socket); -} - -// queries device status and registers callback that will connect -// or reconfigure (if Reconfigure == true) specified device -void TNetlinkDevice::ConnectDevice() +int TNetlinkDevice::StatusHandler(genlmsghdr* header) { - TNetlinkSocket socket; - auto context = std::make_unique(shared_from_this()); - - if (int err = nl_socket_modify_cb( - socket, - NL_CB_VALID, - NL_CB_CUSTOM, - TNetlinkDevice::StatusHandler, - context.release())) // libnl doesn't throw - { - throw TServiceError(E_FAIL) - << "unable to set socket callback: " << nl_geterror(err); - } - - TNetlinkMessage message(socket.GetFamily(), NBD_CMD_STATUS); - message.Put(NBD_ATTR_INDEX, DeviceIndex); - message.Send(socket); -} - -int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) -{ - auto* header = static_cast(nlmsg_data(nlmsg_hdr(message))); - auto context = std::unique_ptr( - static_cast(argument)); - nlattr* attr[NBD_ATTR_MAX + 1] = {}; nlattr* deviceItem[NBD_DEVICE_ITEM_MAX + 1] = {}; nlattr* device[NBD_DEVICE_ATTR_MAX + 1] = {}; @@ -440,7 +464,7 @@ int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) genlmsg_attrlen(header, 0), NULL)) { - context->Device->StartResult.SetValue(MakeError( + StartResult.SetValue(MakeError( E_FAIL, TStringBuilder() << "unable to parse NBD_CMD_STATUS response: " << nl_geterror(err))); @@ -448,7 +472,7 @@ int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) } if (!attr[NBD_ATTR_DEVICE_LIST]) { - context->Device->StartResult.SetValue(MakeError( + StartResult.SetValue(MakeError( E_FAIL, "did not receive NBD_ATTR_DEVICE_LIST")); return NL_STOP; @@ -460,7 +484,7 @@ int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) attr[NBD_ATTR_DEVICE_LIST], deviceItemPolicy)) { - context->Device->StartResult.SetValue(MakeError( + StartResult.SetValue(MakeError( E_FAIL, TStringBuilder() << "unable to parse NBD_ATTR_DEVICE_LIST: " << nl_geterror(err))); @@ -468,7 +492,7 @@ int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) } if (!deviceItem[NBD_DEVICE_ITEM]) { - context->Device->StartResult.SetValue(MakeError( + StartResult.SetValue(MakeError( E_FAIL, "did not receive NBD_DEVICE_ITEM")); return NL_STOP; @@ -480,7 +504,7 @@ int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) deviceItem[NBD_DEVICE_ITEM], devicePolicy)) { - context->Device->StartResult.SetValue(MakeError( + StartResult.SetValue(MakeError( E_FAIL, TStringBuilder() << "unable to parse NBD_DEVICE_ITEM: " << nl_geterror(err))); @@ -488,14 +512,13 @@ int TNetlinkDevice::StatusHandler(nl_msg* message, void* argument) } if (!device[NBD_DEVICE_CONNECTED]) { - context->Device->StartResult.SetValue(MakeError( + StartResult.SetValue(MakeError( E_FAIL, "did not receive NBD_DEVICE_CONNECTED")); return NL_STOP; } - context->Device->DoConnectDevice(nla_get_u8(device[NBD_DEVICE_CONNECTED])); - + DoConnect(nla_get_u8(device[NBD_DEVICE_CONNECTED])); return NL_OK; }