diff --git a/pkg/portfwd/client.go b/pkg/portfwd/client.go index 38a645a4a69..8eccfa7648a 100644 --- a/pkg/portfwd/client.go +++ b/pkg/portfwd/client.go @@ -10,39 +10,38 @@ import ( "github.com/lima-vm/lima/pkg/guestagent/api" guestagentclient "github.com/lima-vm/lima/pkg/guestagent/api/client" "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" ) func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) { defer conn.Close() id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String()) - errCh := make(chan error, 2) stream, err := client.Tunnel(ctx) if err != nil { logrus.Errorf("could not open tcp tunnel for id: %s error:%v", id, err) } + g, _ := errgroup.WithContext(ctx) + rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr} - go func() { + g.Go(func() error { _, err := io.Copy(rw, conn) if errors.Is(err, io.EOF) { - errCh <- nil - return + return nil } - errCh <- err - }() - go func() { + return err + }) + g.Go(func() error { _, err := io.Copy(conn, rw) if errors.Is(err, io.EOF) { - errCh <- nil - return + return nil } - errCh <- err - }() + return err + }) - err = <-errCh - if err != nil { + if err := g.Wait(); err != nil { logrus.Debugf("error in tcp tunnel for id: %s error:%v", id, err) } } @@ -57,19 +56,17 @@ func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgen logrus.Errorf("could not open udp tunnel for id: %s error:%v", id, err) } - errCh := make(chan error, 2) + g, _ := errgroup.WithContext(ctx) - go func() { + g.Go(func() error { buf := make([]byte, 65507) for { n, addr, err := conn.ReadFrom(buf) if errors.Is(err, io.EOF) { - errCh <- nil - return + return nil } if err != nil { - errCh <- err - return + return err } msg := &api.TunnelMessage{ Id: id + "-" + addr.String(), @@ -79,38 +76,32 @@ func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgen UdpTargetAddr: addr.String(), } if err := stream.Send(msg); err != nil { - errCh <- err - return + return err } } - }() + }) - go func() { + g.Go(func() error { for { in, err := stream.Recv() if errors.Is(err, io.EOF) { - errCh <- nil - return + return nil } if err != nil { - errCh <- err - return + return err } addr, err := net.ResolveUDPAddr("udp", in.UdpTargetAddr) if err != nil { - errCh <- err - return + return err } _, err = conn.WriteTo(in.Data, addr) if err != nil { - errCh <- err - return + return err } } - }() + }) - err = <-errCh - if err != nil { + if err := g.Wait(); err != nil { logrus.Debugf("error in udp tunnel for id: %s error:%v", id, err) } }