diff --git a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.cc b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.cc index a4380abad549..356806e4efa8 100644 --- a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.cc +++ b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.cc @@ -16,54 +16,38 @@ #include "absl/status/status.h" #include "xds/data/orca/v3/orca_load_report.pb.h" -using Envoy::MonotonicTime; -using ::envoy::extensions::load_balancing_policies::client_side_weighted_round_robin::v3:: - ClientSideWeightedRoundRobin; -using Envoy::Upstream::Host; -using xds::data::orca::v3::OrcaLoadReport; - -#if TEST_THREAD_SUPPORTED -#define IS_MAIN_OR_TEST_THREAD() (Envoy::Thread::MainThread::isMainOrTestThread()) -#else // !TEST_THREAD_SUPPORTED -- just check for the main thread. -#define IS_MAIN_OR_TEST_THREAD() (Envoy::Thread::MainThread::isMainThread()) -#endif // TEST_THREAD_SUPPORTED +namespace Envoy { +namespace Upstream { namespace { - std::string getHostAddress(const Host* host) { if (host == nullptr || host->address() == nullptr) { return "unknown"; } return host->address()->asString(); } - } // namespace -namespace Envoy { -namespace Upstream { - -ClientSideWeightedRoundRobinLoadBalancer::ClientSideWeightedRoundRobinLoadBalancer( +ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::WorkerLocalLb( const PrioritySet& priority_set, const PrioritySet* local_priority_set, ClusterLbStats& stats, Runtime::Loader& runtime, Random::RandomGenerator& random, const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config, - const ClientSideWeightedRoundRobin& client_side_weighted_round_robin_config, - TimeSource& time_source, Event::Dispatcher& main_thread_dispatcher) + const ClientSideWeightedRoundRobinLbProto& client_side_weighted_round_robin_config, + TimeSource& time_source) : EdfLoadBalancerBase( priority_set, local_priority_set, stats, runtime, random, PROTOBUF_PERCENT_TO_ROUNDED_INTEGER_OR_DEFAULT(common_config, healthy_panic_threshold, 100, 50), LoadBalancerConfigHelper::localityLbConfigFromCommonLbConfig(common_config), /*slow_start_config=*/std::nullopt, time_source) { - ENVOY_LOG(trace, "RoundRobinLbConfig config {}", - client_side_weighted_round_robin_config.DebugString()); - initFromConfig(client_side_weighted_round_robin_config); - if (IS_MAIN_OR_TEST_THREAD()) { - startWeightUpdatesOnMainThread(main_thread_dispatcher); - } + initialize(); + orca_load_report_handler_ = std::make_shared( + client_side_weighted_round_robin_config, time_source_); } // {LoadBalancer} Interface implementation. -void ClientSideWeightedRoundRobinLoadBalancer::refreshHostSource(const HostsSource& source) { +void ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::refreshHostSource( + const HostsSource& source) { // insert() is used here on purpose so that we don't overwrite the index if // the host source already exists. Note that host sources will never be // removed, but given how uncommon this is it probably doesn't matter. @@ -71,33 +55,25 @@ void ClientSideWeightedRoundRobinLoadBalancer::refreshHostSource(const HostsSour // If the list of hosts changes, the order of picks change. Discard the // index. peekahead_index_ = 0; - - if (!IS_MAIN_OR_TEST_THREAD()) { - return; - } - - // On the main thread ensure that all hosts have client side lb policy data. - addClientSideLbPolicyDataToHosts(hostSourceToHosts(source)); } HostConstSharedPtr -ClientSideWeightedRoundRobinLoadBalancer::chooseHost(LoadBalancerContext* context) { +ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::chooseHost(LoadBalancerContext* context) { HostConstSharedPtr host = EdfLoadBalancerBase::chooseHost(context); if (context != nullptr) { - // Configure callbacks to receive ORCA load reports. + // Configure callbacks to receive ORCA load report. context->setOrcaLoadReportCallbacks(orca_load_report_handler_); } return host; } -double ClientSideWeightedRoundRobinLoadBalancer::hostWeight(const Host& host) const { +double ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::hostWeight(const Host& host) const { ENVOY_LOG(trace, "hostWeight {} = {}", getHostAddress(&host), host.weight()); return host.weight(); } -HostConstSharedPtr -ClientSideWeightedRoundRobinLoadBalancer::unweightedHostPeek(const HostVector& hosts_to_use, - const HostsSource& source) { +HostConstSharedPtr ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::unweightedHostPeek( + const HostVector& hosts_to_use, const HostsSource& source) { auto i = rr_indexes_.find(source); if (i == rr_indexes_.end()) { return nullptr; @@ -105,9 +81,8 @@ ClientSideWeightedRoundRobinLoadBalancer::unweightedHostPeek(const HostVector& h return hosts_to_use[(i->second + (peekahead_index_)++) % hosts_to_use.size()]; } -HostConstSharedPtr -ClientSideWeightedRoundRobinLoadBalancer::unweightedHostPick(const HostVector& hosts_to_use, - const HostsSource& source) { +HostConstSharedPtr ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLb::unweightedHostPick( + const HostVector& hosts_to_use, const HostsSource& source) { if (peekahead_index_ > 0) { --peekahead_index_; } @@ -120,7 +95,7 @@ ClientSideWeightedRoundRobinLoadBalancer::unweightedHostPick(const HostVector& h } ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::OrcaLoadReportHandler( - const ClientSideWeightedRoundRobin& client_side_weighted_round_robin_config, + const ClientSideWeightedRoundRobinLbProto& client_side_weighted_round_robin_config, TimeSource& time_source) : time_source_(time_source) { metric_names_for_computing_utilization_ = std::vector( @@ -131,7 +106,7 @@ ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::OrcaLoadReportH } absl::Status ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::onOrcaLoadReport( - const OrcaLoadReport& orca_load_report, const HostDescription& host_description) { + const OrcaLoadReportProto& orca_load_report, const HostDescription& host_description) { const Host* host = dynamic_cast(&host_description); ENVOY_BUG(host != nullptr, "Unable to cast HostDescription to Host."); ENVOY_LOG(trace, @@ -149,12 +124,9 @@ absl::Status ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::on } void ClientSideWeightedRoundRobinLoadBalancer::initFromConfig( - const envoy::extensions::load_balancing_policies::client_side_weighted_round_robin::v3:: - ClientSideWeightedRoundRobin& client_side_weighted_round_robin_config) { - initialize(); - - orca_load_report_handler_ = std::make_shared( - client_side_weighted_round_robin_config, time_source_); + const ClientSideWeightedRoundRobinLbProto& client_side_weighted_round_robin_config) { + ENVOY_LOG(trace, "ClientSideWeightedRoundRobinLbConfig config {}", + client_side_weighted_round_robin_config.DebugString()); blackout_period_ = std::chrono::milliseconds( PROTOBUF_GET_MS_OR_DEFAULT(client_side_weighted_round_robin_config, blackout_period, 10000)); weight_expiration_period_ = std::chrono::milliseconds(PROTOBUF_GET_MS_OR_DEFAULT( @@ -165,9 +137,6 @@ void ClientSideWeightedRoundRobinLoadBalancer::initFromConfig( void ClientSideWeightedRoundRobinLoadBalancer::startWeightUpdatesOnMainThread( Event::Dispatcher& main_thread_dispatcher) { - if (!IS_MAIN_OR_TEST_THREAD()) { - return; - } weight_calculation_timer_ = main_thread_dispatcher.createTimer([this]() -> void { updateWeightsOnMainThread(); weight_calculation_timer_->enableTimer(weight_update_period_); @@ -177,10 +146,8 @@ void ClientSideWeightedRoundRobinLoadBalancer::startWeightUpdatesOnMainThread( void ClientSideWeightedRoundRobinLoadBalancer::updateWeightsOnMainThread() { ENVOY_LOG(trace, "updateWeightsOnMainThread"); - ENVOY_BUG(IS_MAIN_OR_TEST_THREAD(), "Update Weights NOT on MainThread"); - for (uint32_t priority = 0; priority < priority_set_.hostSetsPerPriority().size(); ++priority) { - HostsSource source(priority, HostsSource::SourceType::AllHosts); - updateWeightsOnHosts(hostSourceToHosts(source)); + for (const HostSetPtr& host_set : priority_set_.hostSetsPerPriority()) { + updateWeightsOnHosts(host_set->hosts()); } } @@ -246,7 +213,7 @@ ClientSideWeightedRoundRobinLoadBalancer::getClientSideWeightIfValidFromHost( absl::MutexLock lock(&client_side_data->mu_); // If non_empty_since_ is too recent, we should use the default weight. if (client_side_data->non_empty_since_ > min_non_empty_since) { - ENVOY_LOG(error, + ENVOY_LOG(trace, "Host {} ClientSideHostLbPolicyData non_empty_since_ is too " "recent: {} > {}", getHostAddress(&host), client_side_data->non_empty_since_.time_since_epoch().count(), @@ -266,7 +233,7 @@ ClientSideWeightedRoundRobinLoadBalancer::getClientSideWeightIfValidFromHost( double ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::getUtilizationFromOrcaReport( - const OrcaLoadReport& orca_load_report, + const OrcaLoadReportProto& orca_load_report, const std::vector& metric_names_for_computing_utilization) { // If application_utilization is valid, use it as the utilization metric. double utilization = orca_load_report.application_utilization(); @@ -285,7 +252,7 @@ ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::getUtilizationF absl::StatusOr ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::calculateWeightFromOrcaReport( - const OrcaLoadReport& orca_load_report, + const OrcaLoadReportProto& orca_load_report, const std::vector& metric_names_for_computing_utilization, double error_utilization_penalty) { double qps = orca_load_report.rps_fractional(); @@ -312,7 +279,7 @@ ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler::calculateWeight } absl::Status ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler:: - updateClientSideDataFromOrcaLoadReport(const OrcaLoadReport& orca_load_report, + updateClientSideDataFromOrcaLoadReport(const OrcaLoadReportProto& orca_load_report, ClientSideHostLbPolicyData& client_side_data) { const absl::StatusOr weight = calculateWeightFromOrcaReport( orca_load_report, metric_names_for_computing_utilization_, error_utilization_penalty_); @@ -331,5 +298,36 @@ absl::Status ClientSideWeightedRoundRobinLoadBalancer::OrcaLoadReportHandler:: return absl::OkStatus(); } +Upstream::LoadBalancerPtr ClientSideWeightedRoundRobinLoadBalancer::WorkerLocalLbFactory::create( + Upstream::LoadBalancerParams params) { + const auto* typed_lb_config = + dynamic_cast(lb_config_.ptr()); + return std::make_unique( + params.priority_set, params.local_priority_set, cluster_info_.lbStats(), runtime_, random_, + cluster_info_.lbConfig(), typed_lb_config->lb_config_, time_source_); +} + +ClientSideWeightedRoundRobinLoadBalancer::ClientSideWeightedRoundRobinLoadBalancer( + OptRef lb_config, const Upstream::ClusterInfo& cluster_info, + const Upstream::PrioritySet& priority_set, Runtime::Loader& runtime, + Envoy::Random::RandomGenerator& random, TimeSource& time_source) + : factory_(std::make_shared(lb_config, cluster_info, priority_set, + runtime, random, time_source)), + lb_config_(lb_config), cluster_info_(cluster_info), priority_set_(priority_set), + runtime_(runtime), random_(random), time_source_(time_source) {} + +absl::Status ClientSideWeightedRoundRobinLoadBalancer::initialize() { + // Ensure that all hosts have client side lb policy data. + for (const HostSetPtr& host_set : priority_set_.hostSetsPerPriority()) { + addClientSideLbPolicyDataToHosts(host_set->hosts()); + } + + const auto* typed_lb_config = + dynamic_cast(lb_config_.ptr()); + initFromConfig(typed_lb_config->lb_config_); + startWeightUpdatesOnMainThread(typed_lb_config->main_thread_dispatcher_); + return absl::OkStatus(); +} + } // namespace Upstream } // namespace Envoy diff --git a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.h b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.h index dc34e994885a..1c36f3e702ca 100644 --- a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.h +++ b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.h @@ -10,12 +10,30 @@ namespace Envoy { namespace Upstream { +using ClientSideWeightedRoundRobinLbProto = envoy::extensions::load_balancing_policies:: + client_side_weighted_round_robin::v3::ClientSideWeightedRoundRobin; +using OrcaLoadReportProto = xds::data::orca::v3::OrcaLoadReport; + +/** + * Load balancer config used to wrap the config proto. + */ +class ClientSideWeightedRoundRobinLbConfig : public Upstream::LoadBalancerConfig { +public: + ClientSideWeightedRoundRobinLbConfig(const ClientSideWeightedRoundRobinLbProto& lb_config, + Event::Dispatcher& main_thread_dispatcher) + : lb_config_(lb_config), main_thread_dispatcher_(main_thread_dispatcher) {} + + const ClientSideWeightedRoundRobinLbProto lb_config_; + Event::Dispatcher& main_thread_dispatcher_; +}; + /** * A client side weighted round robin load balancer. When in weighted mode, EDF * scheduling is used. When in not weighted mode, simple RR index selection is * used. */ -class ClientSideWeightedRoundRobinLoadBalancer : public EdfLoadBalancerBase { +class ClientSideWeightedRoundRobinLoadBalancer : public Upstream::ThreadAwareLoadBalancer, + protected Logger::Loggable { public: // This struct is used to store the client side data for the host. Hosts are // not shared between different clusters, but are shared between load @@ -43,8 +61,7 @@ class ClientSideWeightedRoundRobinLoadBalancer : public EdfLoadBalancerBase { class OrcaLoadReportHandler : public LoadBalancerContext::OrcaLoadReportCallbacks { public: OrcaLoadReportHandler( - const envoy::extensions::load_balancing_policies::client_side_weighted_round_robin::v3:: - ClientSideWeightedRoundRobin& client_side_weighted_round_robin_config, + const ClientSideWeightedRoundRobinLbProto& client_side_weighted_round_robin_config, TimeSource& time_source); ~OrcaLoadReportHandler() override = default; @@ -52,62 +69,107 @@ class ClientSideWeightedRoundRobinLoadBalancer : public EdfLoadBalancerBase { friend class ClientSideWeightedRoundRobinLoadBalancerFriend; // {LoadBalancerContext::OrcaLoadReportCallbacks} implementation. - absl::Status onOrcaLoadReport(const xds::data::orca::v3::OrcaLoadReport& orca_load_report, + absl::Status onOrcaLoadReport(const OrcaLoadReportProto& orca_load_report, const HostDescription& host_description) override; // Get utilization from `orca_load_report` using named metrics specified in // `metric_names_for_computing_utilization`. static double getUtilizationFromOrcaReport( - const xds::data::orca::v3::OrcaLoadReport& orca_load_report, + const OrcaLoadReportProto& orca_load_report, const std::vector& metric_names_for_computing_utilization); // Calculate client side weight from `orca_load_report` using `getUtilizationFromOrcaReport()`, // QPS, EPS and `error_utilization_penalty`. static absl::StatusOr calculateWeightFromOrcaReport( - const xds::data::orca::v3::OrcaLoadReport& orca_load_report, + const OrcaLoadReportProto& orca_load_report, const std::vector& metric_names_for_computing_utilization, double error_utilization_penalty); // Update client side data from `orca_load_report`. Invoked from `onOrcaLoadReport` callback on // the worker thread. - absl::Status updateClientSideDataFromOrcaLoadReport( - const xds::data::orca::v3::OrcaLoadReport& orca_load_report, - ClientSideHostLbPolicyData& client_side_data); + absl::Status + updateClientSideDataFromOrcaLoadReport(const OrcaLoadReportProto& orca_load_report, + ClientSideHostLbPolicyData& client_side_data); std::vector metric_names_for_computing_utilization_; double error_utilization_penalty_; TimeSource& time_source_; }; - ClientSideWeightedRoundRobinLoadBalancer( - const PrioritySet& priority_set, const PrioritySet* local_priority_set, ClusterLbStats& stats, - Runtime::Loader& runtime, Random::RandomGenerator& random, - const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config, - const envoy::extensions::load_balancing_policies::client_side_weighted_round_robin::v3:: - ClientSideWeightedRoundRobin& client_side_weighted_round_robin_config, - TimeSource& time_source, Event::Dispatcher& main_thread_dispatcher); + // This class is used to handle the load balancing on the worker thread. + class WorkerLocalLb : public EdfLoadBalancerBase { + public: + WorkerLocalLb( + const PrioritySet& priority_set, const PrioritySet* local_priority_set, + ClusterLbStats& stats, Runtime::Loader& runtime, Random::RandomGenerator& random, + const envoy::config::cluster::v3::Cluster::CommonLbConfig& common_config, + const ClientSideWeightedRoundRobinLbProto& client_side_weighted_round_robin_config, + TimeSource& time_source); -private: - friend class ClientSideWeightedRoundRobinLoadBalancerFriend; + private: + friend class ClientSideWeightedRoundRobinLoadBalancerFriend; + + // {LoadBalancer} Interface implementation. + void refreshHostSource(const HostsSource& source) override; + + HostConstSharedPtr chooseHost(LoadBalancerContext* context) override; - // {LoadBalancer} Interface implementation. - void refreshHostSource(const HostsSource& source) override; + double hostWeight(const Host& host) const override; + HostConstSharedPtr unweightedHostPeek(const HostVector& hosts_to_use, + const HostsSource& source) override; - HostConstSharedPtr chooseHost(LoadBalancerContext* context) override; + HostConstSharedPtr unweightedHostPick(const HostVector& hosts_to_use, + const HostsSource& source) override; - double hostWeight(const Host& host) const override; - HostConstSharedPtr unweightedHostPeek(const HostVector& hosts_to_use, - const HostsSource& source) override; + uint64_t peekahead_index_{}; + absl::flat_hash_map rr_indexes_; + std::shared_ptr orca_load_report_handler_; + }; + + // Factory used to create worker-local load balancer on the worker thread. + class WorkerLocalLbFactory : public Upstream::LoadBalancerFactory { + public: + WorkerLocalLbFactory(OptRef lb_config, + const Upstream::ClusterInfo& cluster_info, + const Upstream::PrioritySet& priority_set, Runtime::Loader& runtime, + Envoy::Random::RandomGenerator& random, TimeSource& time_source) + : lb_config_(lb_config), cluster_info_(cluster_info), priority_set_(priority_set), + runtime_(runtime), random_(random), time_source_(time_source) {} + + Upstream::LoadBalancerPtr create(Upstream::LoadBalancerParams params) override; - HostConstSharedPtr unweightedHostPick(const HostVector& hosts_to_use, - const HostsSource& source) override; + bool recreateOnHostChange() const override { return false; } + + protected: + OptRef lb_config_; + + const Upstream::ClusterInfo& cluster_info_; + const Upstream::PrioritySet& priority_set_; + Runtime::Loader& runtime_; + Envoy::Random::RandomGenerator& random_; + TimeSource& time_source_; + }; - // Initialize LB policy based on the config. +public: + ClientSideWeightedRoundRobinLoadBalancer(OptRef lb_config, + const Upstream::ClusterInfo& cluster_info, + const Upstream::PrioritySet& priority_set, + Runtime::Loader& runtime, + Envoy::Random::RandomGenerator& random, + TimeSource& time_source); + +private: + friend class ClientSideWeightedRoundRobinLoadBalancerFriend; + + // {Upstream::ThreadAwareLoadBalancer} Interface implementation. + Upstream::LoadBalancerFactorySharedPtr factory() override { return factory_; } + absl::Status initialize() override; + + // Initialize LB based on the config. void initFromConfig( - const envoy::extensions::load_balancing_policies::client_side_weighted_round_robin::v3:: - ClientSideWeightedRoundRobin& client_side_weighted_round_robin_config); + const ClientSideWeightedRoundRobinLbProto& client_side_weighted_round_robin_config); - // Start weight updates on main thread only. + // Start weight updates on the main thread. void startWeightUpdatesOnMainThread(Event::Dispatcher& main_thread_dispatcher); // Update weights using client side host LB policy data for all priority sets. @@ -127,9 +189,16 @@ class ClientSideWeightedRoundRobinLoadBalancer : public EdfLoadBalancerBase { getClientSideWeightIfValidFromHost(const Host& host, const MonotonicTime& min_non_empty_since, const MonotonicTime& max_last_update_time); - uint64_t peekahead_index_{}; - absl::flat_hash_map rr_indexes_; - std::shared_ptr orca_load_report_handler_; + // Factory used to create worker-local load balancers on the worker thread. + std::shared_ptr factory_; + // Data that is also passed to the worker-local load balancer via factory_. + OptRef lb_config_; + const Upstream::ClusterInfo& cluster_info_; + const Upstream::PrioritySet& priority_set_; + Runtime::Loader& runtime_; + Envoy::Random::RandomGenerator& random_; + TimeSource& time_source_; + // Timing parameters for the weight update. std::chrono::milliseconds blackout_period_; std::chrono::milliseconds weight_expiration_period_; diff --git a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.cc b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.cc index c43d65866f4c..0f0c95ee8226 100644 --- a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.cc +++ b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.cc @@ -7,24 +7,6 @@ namespace Extensions { namespace LoadBalancingPolices { namespace ClientSideWeightedRoundRobin { -ClientSideWeightedRoundRobinLbConfig::ClientSideWeightedRoundRobinLbConfig( - const ClientSideWeightedRoundRobinLbProto& lb_config, - Envoy::Event::Dispatcher& main_thread_dispatcher) - : lb_config_(lb_config), main_thread_dispatcher_(main_thread_dispatcher) {} - -Upstream::LoadBalancerPtr ClientSideWeightedRoundRobinCreator::operator()( - Upstream::LoadBalancerParams params, OptRef lb_config, - const Upstream::ClusterInfo& cluster_info, const Upstream::PrioritySet&, - Runtime::Loader& runtime, Random::RandomGenerator& random, TimeSource& time_source) { - const auto typed_lb_config = - dynamic_cast(lb_config.ptr()); - - return std::make_unique( - params.priority_set, params.local_priority_set, cluster_info.lbStats(), runtime, random, - cluster_info.lbConfig(), typed_lb_config->lb_config_, time_source, - typed_lb_config->main_thread_dispatcher_); -} - /** * Static registration for the Factory. @see RegisterFactory. */ diff --git a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.h b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.h index 970f57292659..220c6688a3df 100644 --- a/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.h +++ b/source/extensions/load_balancing_policies/client_side_weighted_round_robin/config.h @@ -5,6 +5,7 @@ #include "envoy/upstream/load_balancer.h" #include "source/common/common/logger.h" +#include "source/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb.h" #include "source/extensions/load_balancing_policies/common/factory_base.h" namespace Envoy { @@ -14,40 +15,30 @@ namespace ClientSideWeightedRoundRobin { using ClientSideWeightedRoundRobinLbProto = envoy::extensions::load_balancing_policies:: client_side_weighted_round_robin::v3::ClientSideWeightedRoundRobin; -using ClusterProto = envoy::config::cluster::v3::Cluster; +// using ClusterProto = envoy::config::cluster::v3::Cluster; -/** - * Load balancer config that used to wrap the proto config. - */ -class ClientSideWeightedRoundRobinLbConfig : public Upstream::LoadBalancerConfig { +class Factory : public Upstream::TypedLoadBalancerFactoryBase { public: - ClientSideWeightedRoundRobinLbConfig(const ClientSideWeightedRoundRobinLbProto& lb_config, - Event::Dispatcher& main_thread_dispatcher); - - const ClientSideWeightedRoundRobinLbProto lb_config_; - Event::Dispatcher& main_thread_dispatcher_; -}; - -struct ClientSideWeightedRoundRobinCreator : public Logger::Loggable { - Upstream::LoadBalancerPtr operator()(Upstream::LoadBalancerParams params, - OptRef lb_config, - const Upstream::ClusterInfo& cluster_info, - const Upstream::PrioritySet& priority_set, - Runtime::Loader& runtime, Random::RandomGenerator& random, - TimeSource& time_source); -}; - -class Factory : public Common::FactoryBase { -public: - Factory() : FactoryBase("envoy.load_balancing_policies.client_side_weighted_round_robin") {} + Factory() + : Upstream::TypedLoadBalancerFactoryBase( + "envoy.load_balancing_policies.client_side_weighted_round_robin") {} + + Upstream::ThreadAwareLoadBalancerPtr create(OptRef lb_config, + const Upstream::ClusterInfo& cluster_info, + const Upstream::PrioritySet& priority_set, + Runtime::Loader& runtime, + Envoy::Random::RandomGenerator& random, + TimeSource& time_source) override { + return std::make_unique( + lb_config, cluster_info, priority_set, runtime, random, time_source); + } Upstream::LoadBalancerConfigPtr loadConfig(Upstream::LoadBalancerFactoryContext& context, const Protobuf::Message& config, ProtobufMessage::ValidationVisitor&) override { const auto& lb_config = dynamic_cast(config); - return Upstream::LoadBalancerConfigPtr{ - new ClientSideWeightedRoundRobinLbConfig(lb_config, context.mainThreadDispatcher())}; + return Upstream::LoadBalancerConfigPtr{new Upstream::ClientSideWeightedRoundRobinLbConfig( + lb_config, context.mainThreadDispatcher())}; } }; diff --git a/test/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb_test.cc b/test/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb_test.cc index ff1863a9e93f..97d152a1ece7 100644 --- a/test/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb_test.cc +++ b/test/extensions/load_balancing_policies/client_side_weighted_round_robin/client_side_weighted_round_robin_lb_test.cc @@ -17,17 +17,22 @@ namespace Upstream { class ClientSideWeightedRoundRobinLoadBalancerFriend { public: explicit ClientSideWeightedRoundRobinLoadBalancerFriend( - std::shared_ptr lb) - : lb_(std::move(lb)) {} + std::shared_ptr lb, + std::shared_ptr worker_lb) + : lb_(std::move(lb)), worker_lb_(std::move(worker_lb)) {} ~ClientSideWeightedRoundRobinLoadBalancerFriend() = default; - HostConstSharedPtr chooseHost(LoadBalancerContext* context) { return lb_->chooseHost(context); } + HostConstSharedPtr chooseHost(LoadBalancerContext* context) { + return worker_lb_->chooseHost(context); + } HostConstSharedPtr peekAnotherHost(LoadBalancerContext* context) { - return lb_->peekAnotherHost(context); + return worker_lb_->peekAnotherHost(context); } + absl::Status initialize() { return lb_->initialize(); } + void updateWeightsOnMainThread() { lb_->updateWeightsOnMainThread(); } void updateWeightsOnHosts(const HostVector& hosts) { lb_->updateWeightsOnHosts(hosts); } @@ -62,12 +67,13 @@ class ClientSideWeightedRoundRobinLoadBalancerFriend { absl::Status updateClientSideDataFromOrcaLoadReport( const xds::data::orca::v3::OrcaLoadReport& orca_load_report, ClientSideWeightedRoundRobinLoadBalancer::ClientSideHostLbPolicyData& client_side_data) { - return lb_->orca_load_report_handler_->updateClientSideDataFromOrcaLoadReport(orca_load_report, - client_side_data); + return worker_lb_->orca_load_report_handler_->updateClientSideDataFromOrcaLoadReport( + orca_load_report, client_side_data); } private: std::shared_ptr lb_; + std::shared_ptr worker_lb_; }; namespace { @@ -111,8 +117,13 @@ class ClientSideWeightedRoundRobinLoadBalancerTest : public LoadBalancerTestBase lb_ = std::make_shared( std::make_shared( + lb_config_, cluster_info_, priority_set_, runtime_, random_, simTime()), + std::make_shared( priority_set_, local_priority_set_.get(), stats_, runtime_, random_, common_config_, - client_side_weighted_round_robin_config_, simTime(), dispatcher_)); + client_side_weighted_round_robin_config_, simTime())); + + // Initialize the thread aware load balancer from config. + ASSERT_EQ(lb_->initialize(), absl::OkStatus()); } // Updates priority 0 with the given hosts and hosts_per_locality. @@ -136,6 +147,7 @@ class ClientSideWeightedRoundRobinLoadBalancerTest : public LoadBalancerTestBase envoy::extensions::load_balancing_policies::client_side_weighted_round_robin::v3:: ClientSideWeightedRoundRobin client_side_weighted_round_robin_config_; + std::shared_ptr local_priority_set_; std::shared_ptr lb_; HostsPerLocalityConstSharedPtr empty_locality_; @@ -143,6 +155,9 @@ class ClientSideWeightedRoundRobinLoadBalancerTest : public LoadBalancerTestBase NiceMock lb_context_; NiceMock dispatcher_; + NiceMock cluster_info_; + ClientSideWeightedRoundRobinLbConfig lb_config_ = + ClientSideWeightedRoundRobinLbConfig(client_side_weighted_round_robin_config_, dispatcher_); }; ////////////////////////////////////////////////////// @@ -618,8 +633,10 @@ TEST_P(ClientSideWeightedRoundRobinLoadBalancerTest, WeightedInitializationPicks EXPECT_CALL(random_, random()).Times(2).WillRepeatedly(Return(i)); ClientSideWeightedRoundRobinLoadBalancerFriend lb( std::make_shared( + lb_config_, cluster_info_, priority_set_, runtime_, random_, simTime()), + std::make_shared( priority_set_, local_priority_set_.get(), stats_, runtime_, random_, common_config_, - client_side_weighted_round_robin_config_, simTime(), dispatcher_)); + client_side_weighted_round_robin_config_, simTime())); const auto& host = lb.chooseHost(nullptr); host_picked_count_map[host]++; }