From dd54b181f2f98f7aa25034ffb053147cacb97c25 Mon Sep 17 00:00:00 2001 From: kpango Date: Sun, 27 Oct 2024 13:02:04 +0900 Subject: [PATCH] fix Signed-off-by: kpango --- .../client/v1/client/discoverer/discover.go | 97 +++++++++---------- internal/net/dialer.go | 10 +- pkg/discoverer/k8s/service/discover.go | 18 ++-- 3 files changed, 64 insertions(+), 61 deletions(-) diff --git a/internal/client/v1/client/discoverer/discover.go b/internal/client/v1/client/discoverer/discover.go index 83c98c46db2..0468da52886 100644 --- a/internal/client/v1/client/discoverer/discover.go +++ b/internal/client/v1/client/discoverer/discover.go @@ -18,8 +18,10 @@ package discoverer import ( + "cmp" "context" "reflect" + "slices" "sync/atomic" "time" @@ -96,10 +98,8 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) { } } - ech := make(chan error, 100) - addrs, err := c.dnsDiscovery(ctx, ech) + addrs, err := c.dnsDiscovery(ctx) if err != nil { - close(ech) return nil, err } c.addrs.Store(&addrs) @@ -116,20 +116,17 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) { if c.client != nil { aech, err = c.client.StartConnectionMonitor(ctx) if err != nil { - close(ech) return nil, err } } } - err = c.discover(ctx, ech) + err = c.discover(ctx) if err != nil { - close(ech) return nil, errors.Join(c.dscClient.Close(ctx), err) } - + ech := make(chan error, 100) c.eg.Go(safety.RecoverFunc(func() (err error) { - defer close(ech) dt := time.NewTicker(c.dscDur) defer dt.Stop() finalize := func() (err error) { @@ -158,7 +155,7 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) { case err = <-aech: case err = <-rrech: case <-dt.C: - err = c.discover(ctx, ech) + err = c.discover(ctx) } if err != nil { log.Error(err) @@ -177,14 +174,11 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) { func (c *client) GetAddrs(ctx context.Context) (addrs []string) { a := c.addrs.Load() if a == nil { - ips, err := net.DefaultResolver.LookupIPAddr(ctx, c.dns) + var err error + addrs, err = c.dnsDiscovery(ctx) if err != nil { return nil } - addrs = make([]string, 0, len(ips)) - for _, ip := range ips { - addrs = append(addrs, ip.String()) - } } else { addrs = *a } @@ -238,7 +232,7 @@ func (c *client) disconnect(ctx context.Context, addr string) (err error) { return } -func (c *client) dnsDiscovery(ctx context.Context, ech chan<- error) (addrs []string, err error) { +func (c *client) dnsDiscovery(ctx context.Context) (addrs []string, err error) { ips, err := net.DefaultResolver.LookupIPAddr(ctx, c.dns) if err != nil || len(ips) == 0 { return nil, errors.ErrAddrCouldNotDiscover(err, c.dns) @@ -249,7 +243,6 @@ func (c *client) dnsDiscovery(ctx context.Context, ech chan<- error) (addrs []st addr := net.JoinHostPort(ip.String(), uint16(c.port)) if err = c.connect(ctx, addr); err != nil { log.Debugf("dns discovery connect for addr = %s from dns = %s failed %v", addr, c.dns, err) - ech <- err } else { log.Debugf("dns discovery connect for addr = %s from dns = %s succeeded", addr, c.dns) addrs = append(addrs, addr) @@ -264,7 +257,7 @@ func (c *client) dnsDiscovery(ctx context.Context, ech chan<- error) (addrs []st return addrs, nil } -func (c *client) discover(ctx context.Context, ech chan<- error) (err error) { +func (c *client) discover(ctx context.Context) (err error) { if c.dscClient == nil || (c.autoconn && c.client == nil) { return errors.ErrGRPCClientNotFound } @@ -272,7 +265,7 @@ func (c *client) discover(ctx context.Context, ech chan<- error) (err error) { var connected []string if bo := c.client.GetBackoff(); bo != nil { _, err = bo.Do(ctx, func(ctx context.Context) (any, bool, error) { - connected, err = c.updateDiscoveryInfo(ctx, ech) + connected, err = c.updateDiscoveryInfo(ctx) if err != nil { if !errors.Is(err, errors.ErrGRPCClientNotFound) && !errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) { @@ -283,11 +276,11 @@ func (c *client) discover(ctx context.Context, ech chan<- error) (err error) { return nil, false, nil }) } else { - connected, err = c.updateDiscoveryInfo(ctx, ech) + connected, err = c.updateDiscoveryInfo(ctx) } if err != nil { log.Warnf("failed to discover addrs from discoverer API, error: %v,\ttrying to dns discovery from %s...", err, c.dns) - connected, err = c.dnsDiscovery(ctx, ech) + connected, err = c.dnsDiscovery(ctx) if err != nil { return err } @@ -295,12 +288,10 @@ func (c *client) discover(ctx context.Context, ech chan<- error) (err error) { oldAddrs := c.GetAddrs(ctx) c.addrs.Store(&connected) - return c.disconnectOldAddrs(ctx, oldAddrs, connected, ech) + return c.disconnectOldAddrs(ctx, oldAddrs, connected) } -func (c *client) updateDiscoveryInfo( - ctx context.Context, ech chan<- error, -) (connected []string, err error) { +func (c *client) updateDiscoveryInfo(ctx context.Context) (connected []string, err error) { nodes, err := c.discoverNodes(ctx) if err != nil { log.Warnf("error detected when discovering nodes,\terrors: %v", err) @@ -310,7 +301,7 @@ func (c *client) updateDiscoveryInfo( log.Warn("no nodes found") return nil, errors.ErrNodeNotFound("all") } - connected, err = c.discoverAddrs(ctx, nodes, ech) + connected, err = c.discoverAddrs(ctx, nodes) if err != nil { return nil, err } @@ -343,21 +334,39 @@ func (c *client) discoverNodes(ctx context.Context) (nodes *payload.Info_Nodes, } return nodes, nil }) - return nodes, err + if err != nil { + return nil, err + } + slices.SortFunc(nodes.Nodes, func(left, right *payload.Info_Node) int { + return cmp.Compare(left.GetMemory().GetUsage(), right.GetMemory().GetUsage()) + }) + return nodes, nil } func (c *client) discoverAddrs( - ctx context.Context, nodes *payload.Info_Nodes, ech chan<- error, + ctx context.Context, nodes *payload.Info_Nodes, ) (addrs []string, err error) { + if nodes == nil { + return nil, errors.ErrAddrCouldNotDiscover(err, c.dns) + } maxPodLen := 0 podLength := 0 - for _, node := range nodes.GetNodes() { - l := len(node.GetPods().GetPods()) - podLength += l - if l > maxPodLen { - maxPodLen = l + for i, node := range nodes.GetNodes() { + if node != nil && node.GetPods() != nil && node.GetPods().GetPods() != nil { + l := len(node.GetPods().GetPods()) + podLength += l + if l > maxPodLen { + maxPodLen = l + } + slices.SortFunc(nodes.Nodes[i].Pods.Pods, func(left, right *payload.Info_Pod) int { + return cmp.Compare(left.GetMemory().GetUsage(), right.GetMemory().GetUsage()) + }) } } + nbody, err := nodes.MarshalJSON() + if err != nil { + log.Debug(string(nbody)) + } addrs = make([]string, 0, podLength) for i := 0; i < maxPodLen; i++ { for _, node := range nodes.GetNodes() { @@ -371,11 +380,7 @@ func (c *client) discoverAddrs( len(node.GetPods().GetPods()[i].GetIp()) != 0 { addr := net.JoinHostPort(node.GetPods().GetPods()[i].GetIp(), uint16(c.port)) if err = c.connect(ctx, addr); err != nil { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case ech <- errors.ErrAddrCouldNotDiscover(err, addr): - } + log.Error(errors.ErrAddrCouldNotDiscover(err, addr)) err = nil } else { addrs = append(addrs, addr) @@ -388,7 +393,7 @@ func (c *client) discoverAddrs( } func (c *client) disconnectOldAddrs( - ctx context.Context, oldAddrs, connectedAddrs []string, ech chan<- error, + ctx context.Context, oldAddrs, connectedAddrs []string, ) (err error) { if !c.autoconn { return nil @@ -404,7 +409,7 @@ func (c *client) disconnectOldAddrs( c.eg.Go(safety.RecoverFunc(func() error { err = c.disconnect(ctx, old) if err != nil { - ech <- err + log.Error(err) } return nil })) @@ -420,22 +425,12 @@ func (c *client) disconnectOldAddrs( if !ok { err = c.disconnect(ctx, addr) if err != nil { - select { - case <-ctx.Done(): - return errors.Join(ctx.Err(), err) - case ech <- err: - return err - } + return err } } return nil }); err != nil { - select { - case <-ctx.Done(): - return errors.Join(ctx.Err(), err) - case ech <- err: - return err - } + log.Error(err) } } return nil diff --git a/internal/net/dialer.go b/internal/net/dialer.go index 73ca233f504..2a5b23421d1 100644 --- a/internal/net/dialer.go +++ b/internal/net/dialer.go @@ -119,7 +119,6 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) { netpoll.SetLoadBalance(netpoll.RoundRobin) d.npDialer = netpoll.NewDialer() - d.dialer = d.dial if d.enableDNSCache { if d.dnsRefreshDuration > d.dnsCacheExpiration { @@ -330,7 +329,14 @@ func (d *dialer) dial(ctx context.Context, network, addr string) (conn Conn, err } else { conn, err = d.npDialer.DialConnection(network, addr, d.der.Timeout) if err != nil { - conn, err = d.der.DialContext(ctx, network, addr) + if conn != nil { + err = errors.Join(err, conn.Close()) + } + var ierr error + conn, ierr = d.der.DialContext(ctx, network, addr) + if ierr != nil { + err = errors.Join(err, ierr) + } } } return err diff --git a/pkg/discoverer/k8s/service/discover.go b/pkg/discoverer/k8s/service/discover.go index c95b02420dd..f9653ee2892 100644 --- a/pkg/discoverer/k8s/service/discover.go +++ b/pkg/discoverer/k8s/service/discover.go @@ -513,7 +513,6 @@ func (d *discoverer) GetPods(req *payload.Discoverer_Request) (pods *payload.Inf func (d *discoverer) GetNodes( req *payload.Discoverer_Request, ) (nodes *payload.Info_Nodes, err error) { - nodes = new(payload.Info_Nodes) nbn, ok := d.nodeByName.Load().(map[string]*payload.Info_Node) if !ok { return nil, errors.ErrInvalidDiscoveryCache @@ -527,10 +526,15 @@ func (d *discoverer) GetNodes( if err == nil { n.Pods = ps } - nodes.Nodes = append(nodes.GetNodes(), n) - return nodes, nil + return &payload.Info_Nodes{ + Nodes: []*payload.Info_Node{ + n, + }, + }, nil + } + nodes = &payload.Info_Nodes{ + Nodes: make([]*payload.Info_Node, len(nbn)), } - ns := nodes.Nodes for name, n := range nbn { req.Node = name if n.GetPods() != nil { @@ -546,13 +550,11 @@ func (d *discoverer) GetNodes( n.Pods = ps } } - ns = append(ns, n) + nodes.Nodes = append(nodes.Nodes, n) } - slices.SortFunc(ns, func(left, right *payload.Info_Node) int { + slices.SortFunc(nodes.Nodes, func(left, right *payload.Info_Node) int { return cmp.Compare(left.GetMemory().GetUsage(), right.GetMemory().GetUsage()) }) - - nodes.Nodes = ns return nodes, nil }