Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
Signed-off-by: kpango <kpango@vdaas.org>
  • Loading branch information
kpango committed Oct 28, 2024
1 parent f7d704f commit dd54b18
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 61 deletions.
97 changes: 46 additions & 51 deletions internal/client/v1/client/discoverer/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
package discoverer

import (
"cmp"
"context"
"reflect"
"slices"
"sync/atomic"
"time"

Expand Down Expand Up @@ -96,10 +98,8 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) {
}
}

ech := make(chan error, 100)
addrs, err := c.dnsDiscovery(ctx, ech)
addrs, err := c.dnsDiscovery(ctx)
if err != nil {
close(ech)
return nil, err
}
c.addrs.Store(&addrs)
Expand All @@ -116,20 +116,17 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) {
if c.client != nil {
aech, err = c.client.StartConnectionMonitor(ctx)
if err != nil {
close(ech)
return nil, err
}
}
}

err = c.discover(ctx, ech)
err = c.discover(ctx)
if err != nil {
close(ech)
return nil, errors.Join(c.dscClient.Close(ctx), err)
}

ech := make(chan error, 100)
c.eg.Go(safety.RecoverFunc(func() (err error) {
defer close(ech)
dt := time.NewTicker(c.dscDur)
defer dt.Stop()
finalize := func() (err error) {
Expand Down Expand Up @@ -158,7 +155,7 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) {
case err = <-aech:
case err = <-rrech:
case <-dt.C:
err = c.discover(ctx, ech)
err = c.discover(ctx)
}
if err != nil {
log.Error(err)
Expand All @@ -177,14 +174,11 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) {
func (c *client) GetAddrs(ctx context.Context) (addrs []string) {
a := c.addrs.Load()
if a == nil {
ips, err := net.DefaultResolver.LookupIPAddr(ctx, c.dns)
var err error
addrs, err = c.dnsDiscovery(ctx)
if err != nil {
return nil
}
addrs = make([]string, 0, len(ips))
for _, ip := range ips {
addrs = append(addrs, ip.String())
}
} else {
addrs = *a
}
Expand Down Expand Up @@ -238,7 +232,7 @@ func (c *client) disconnect(ctx context.Context, addr string) (err error) {
return
}

func (c *client) dnsDiscovery(ctx context.Context, ech chan<- error) (addrs []string, err error) {
func (c *client) dnsDiscovery(ctx context.Context) (addrs []string, err error) {
ips, err := net.DefaultResolver.LookupIPAddr(ctx, c.dns)
if err != nil || len(ips) == 0 {
return nil, errors.ErrAddrCouldNotDiscover(err, c.dns)
Expand All @@ -249,7 +243,6 @@ func (c *client) dnsDiscovery(ctx context.Context, ech chan<- error) (addrs []st
addr := net.JoinHostPort(ip.String(), uint16(c.port))
if err = c.connect(ctx, addr); err != nil {
log.Debugf("dns discovery connect for addr = %s from dns = %s failed %v", addr, c.dns, err)
ech <- err
} else {
log.Debugf("dns discovery connect for addr = %s from dns = %s succeeded", addr, c.dns)
addrs = append(addrs, addr)
Expand All @@ -264,15 +257,15 @@ func (c *client) dnsDiscovery(ctx context.Context, ech chan<- error) (addrs []st
return addrs, nil
}

func (c *client) discover(ctx context.Context, ech chan<- error) (err error) {
func (c *client) discover(ctx context.Context) (err error) {
if c.dscClient == nil || (c.autoconn && c.client == nil) {
return errors.ErrGRPCClientNotFound
}

var connected []string
if bo := c.client.GetBackoff(); bo != nil {
_, err = bo.Do(ctx, func(ctx context.Context) (any, bool, error) {
connected, err = c.updateDiscoveryInfo(ctx, ech)
connected, err = c.updateDiscoveryInfo(ctx)
if err != nil {
if !errors.Is(err, errors.ErrGRPCClientNotFound) &&
!errors.Is(err, errors.ErrGRPCClientConnNotFound("*")) {
Expand All @@ -283,24 +276,22 @@ func (c *client) discover(ctx context.Context, ech chan<- error) (err error) {
return nil, false, nil
})
} else {
connected, err = c.updateDiscoveryInfo(ctx, ech)
connected, err = c.updateDiscoveryInfo(ctx)
}
if err != nil {
log.Warnf("failed to discover addrs from discoverer API, error: %v,\ttrying to dns discovery from %s...", err, c.dns)
connected, err = c.dnsDiscovery(ctx, ech)
connected, err = c.dnsDiscovery(ctx)
if err != nil {
return err
}
}

oldAddrs := c.GetAddrs(ctx)
c.addrs.Store(&connected)
return c.disconnectOldAddrs(ctx, oldAddrs, connected, ech)
return c.disconnectOldAddrs(ctx, oldAddrs, connected)
}

func (c *client) updateDiscoveryInfo(
ctx context.Context, ech chan<- error,
) (connected []string, err error) {
func (c *client) updateDiscoveryInfo(ctx context.Context) (connected []string, err error) {
nodes, err := c.discoverNodes(ctx)
if err != nil {
log.Warnf("error detected when discovering nodes,\terrors: %v", err)
Expand All @@ -310,7 +301,7 @@ func (c *client) updateDiscoveryInfo(
log.Warn("no nodes found")
return nil, errors.ErrNodeNotFound("all")
}
connected, err = c.discoverAddrs(ctx, nodes, ech)
connected, err = c.discoverAddrs(ctx, nodes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -343,21 +334,39 @@ func (c *client) discoverNodes(ctx context.Context) (nodes *payload.Info_Nodes,
}
return nodes, nil
})
return nodes, err
if err != nil {
return nil, err
}
slices.SortFunc(nodes.Nodes, func(left, right *payload.Info_Node) int {
return cmp.Compare(left.GetMemory().GetUsage(), right.GetMemory().GetUsage())
})
return nodes, nil
}

func (c *client) discoverAddrs(
ctx context.Context, nodes *payload.Info_Nodes, ech chan<- error,
ctx context.Context, nodes *payload.Info_Nodes,
) (addrs []string, err error) {
if nodes == nil {
return nil, errors.ErrAddrCouldNotDiscover(err, c.dns)
}
maxPodLen := 0
podLength := 0
for _, node := range nodes.GetNodes() {
l := len(node.GetPods().GetPods())
podLength += l
if l > maxPodLen {
maxPodLen = l
for i, node := range nodes.GetNodes() {
if node != nil && node.GetPods() != nil && node.GetPods().GetPods() != nil {
l := len(node.GetPods().GetPods())
podLength += l
if l > maxPodLen {
maxPodLen = l
}
slices.SortFunc(nodes.Nodes[i].Pods.Pods, func(left, right *payload.Info_Pod) int {
return cmp.Compare(left.GetMemory().GetUsage(), right.GetMemory().GetUsage())
})
}
}
nbody, err := nodes.MarshalJSON()
if err != nil {
log.Debug(string(nbody))
}
addrs = make([]string, 0, podLength)
for i := 0; i < maxPodLen; i++ {
for _, node := range nodes.GetNodes() {
Expand All @@ -371,11 +380,7 @@ func (c *client) discoverAddrs(
len(node.GetPods().GetPods()[i].GetIp()) != 0 {
addr := net.JoinHostPort(node.GetPods().GetPods()[i].GetIp(), uint16(c.port))
if err = c.connect(ctx, addr); err != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
case ech <- errors.ErrAddrCouldNotDiscover(err, addr):
}
log.Error(errors.ErrAddrCouldNotDiscover(err, addr))
err = nil
} else {
addrs = append(addrs, addr)
Expand All @@ -388,7 +393,7 @@ func (c *client) discoverAddrs(
}

func (c *client) disconnectOldAddrs(
ctx context.Context, oldAddrs, connectedAddrs []string, ech chan<- error,
ctx context.Context, oldAddrs, connectedAddrs []string,
) (err error) {
if !c.autoconn {
return nil
Expand All @@ -404,7 +409,7 @@ func (c *client) disconnectOldAddrs(
c.eg.Go(safety.RecoverFunc(func() error {
err = c.disconnect(ctx, old)
if err != nil {
ech <- err
log.Error(err)
}
return nil
}))
Expand All @@ -420,22 +425,12 @@ func (c *client) disconnectOldAddrs(
if !ok {
err = c.disconnect(ctx, addr)
if err != nil {
select {
case <-ctx.Done():
return errors.Join(ctx.Err(), err)
case ech <- err:
return err
}
return err
}
}
return nil
}); err != nil {
select {
case <-ctx.Done():
return errors.Join(ctx.Err(), err)
case ech <- err:
return err
}
log.Error(err)
}
}
return nil
Expand Down
10 changes: 8 additions & 2 deletions internal/net/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ func NewDialer(opts ...DialerOption) (der Dialer, err error) {

netpoll.SetLoadBalance(netpoll.RoundRobin)
d.npDialer = netpoll.NewDialer()

d.dialer = d.dial
if d.enableDNSCache {
if d.dnsRefreshDuration > d.dnsCacheExpiration {
Expand Down Expand Up @@ -330,7 +329,14 @@ func (d *dialer) dial(ctx context.Context, network, addr string) (conn Conn, err
} else {
conn, err = d.npDialer.DialConnection(network, addr, d.der.Timeout)
if err != nil {
conn, err = d.der.DialContext(ctx, network, addr)
if conn != nil {
err = errors.Join(err, conn.Close())
}
var ierr error
conn, ierr = d.der.DialContext(ctx, network, addr)
if ierr != nil {
err = errors.Join(err, ierr)
}
}
}
return err
Expand Down
18 changes: 10 additions & 8 deletions pkg/discoverer/k8s/service/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,6 @@ func (d *discoverer) GetPods(req *payload.Discoverer_Request) (pods *payload.Inf
func (d *discoverer) GetNodes(
req *payload.Discoverer_Request,
) (nodes *payload.Info_Nodes, err error) {
nodes = new(payload.Info_Nodes)
nbn, ok := d.nodeByName.Load().(map[string]*payload.Info_Node)
if !ok {
return nil, errors.ErrInvalidDiscoveryCache
Expand All @@ -527,10 +526,15 @@ func (d *discoverer) GetNodes(
if err == nil {
n.Pods = ps
}
nodes.Nodes = append(nodes.GetNodes(), n)
return nodes, nil
return &payload.Info_Nodes{
Nodes: []*payload.Info_Node{
n,
},
}, nil
}
nodes = &payload.Info_Nodes{
Nodes: make([]*payload.Info_Node, len(nbn)),
}
ns := nodes.Nodes
for name, n := range nbn {
req.Node = name
if n.GetPods() != nil {
Expand All @@ -546,13 +550,11 @@ func (d *discoverer) GetNodes(
n.Pods = ps
}
}
ns = append(ns, n)
nodes.Nodes = append(nodes.Nodes, n)
}
slices.SortFunc(ns, func(left, right *payload.Info_Node) int {
slices.SortFunc(nodes.Nodes, func(left, right *payload.Info_Node) int {
return cmp.Compare(left.GetMemory().GetUsage(), right.GetMemory().GetUsage())
})

nodes.Nodes = ns
return nodes, nil
}

Expand Down

0 comments on commit dd54b18

Please sign in to comment.