diff --git a/cmd/webhook/main.go b/cmd/webhook/main.go index 2f1869a..2c159a9 100644 --- a/cmd/webhook/main.go +++ b/cmd/webhook/main.go @@ -3,20 +3,19 @@ package main import ( "context" "crypto/tls" - "errors" "flag" "log" - "net/http" "os" "os/signal" + "sync" "syscall" - "time" "github.com/prometheus/client_golang/prometheus" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer" "github.com/orange-cloudavenue/kube-image-updater/internal/health" + "github.com/orange-cloudavenue/kube-image-updater/internal/httpserver" client "github.com/orange-cloudavenue/kube-image-updater/internal/kubeclient" "github.com/orange-cloudavenue/kube-image-updater/internal/metrics" ) @@ -60,24 +59,25 @@ func init() { if os.Getenv("POD_NAMESPACE") != "" { webhookNamespace = os.Getenv("POD_NAMESPACE") } + // init flags + flag.StringVar(&webhookPort, "webhook-port", webhookPort, "Webhook server port.ex: :8443") + flag.StringVar(&webhookNamespace, "namespace", webhookNamespace, "Kimup Webhook Mutating namespace.") + flag.StringVar(&webhookServiceName, "service-name", webhookServiceName, "Kimup Webhook Mutating service name.") + flag.BoolVar(&insideCluster, "inside-cluster", true, "True if running inside k8s cluster.") + flag.Parse() } // Start http server for webhook func main() { + // !-- Context --! // ctx, cancel := context.WithCancel(context.Background()) defer cancel() + wg := sync.WaitGroup{} signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL) var err error - flag.StringVar(&webhookPort, "webhook-port", webhookPort, "Webhook server port.ex: :8443") - flag.StringVar(&webhookNamespace, "namespace", webhookNamespace, "Kimup Webhook Mutating namespace.") - flag.StringVar(&webhookServiceName, "service-name", webhookServiceName, "Kimup Webhook Mutating service name.") - - flag.BoolVar(&insideCluster, "inside-cluster", true, "True if running inside k8s cluster.") - - flag.Parse() // homedir for kubeconfig homedir, err := os.UserHomeDir() @@ -100,48 +100,42 @@ func main() { signalChan <- os.Interrupt } - mux := http.NewServeMux() - mux.HandleFunc(webhookPathMutate, serveHandler) - - // define http server and server handler - s := &http.Server{ - Addr: webhookPort, - Handler: mux, - ReadTimeout: 10 * time.Second, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{pair}, - MinVersion: tls.VersionTLS12, - // InsecureSkipVerify: true, //nolint:gosec - }, + // Start the webhook server + wg.Add(1) + if err := StarWebhook(ctx, &wg, &tls.Config{ + Certificates: []tls.Certificate{pair}, + MinVersion: tls.VersionTLS12, + // InsecureSkipVerify: true, //nolint:gosec + }); err != nil { + errorLogger.Fatalf("Failed to start webhook server: %v", err) } - // start the HTTP server - go func() { - infoLogger.Printf("Starting webhook server on %s from insideCluster=%v", s.Addr, insideCluster) - if err = s.ListenAndServeTLS("", ""); !errors.Is(err, http.ErrServerClosed) { - log.Fatalf("Failed to start webhook server: %v", err) - } else { - log.Printf("Shutting down webhook server on %s", s.Addr) - } - }() - // !-- Prometheus metrics server --! // // start the metrics server - if err := metrics.ServeProm(ctx); err != nil { + if err := metrics.StartProm(ctx, &wg); err != nil { errorLogger.Fatalf("Failed to start metrics server: %v", err) } // !-- Health check server --! // // start the health check server - if err := health.ServeHealth(ctx); err != nil { + if err := health.StartHealth(ctx, &wg); err != nil { errorLogger.Fatalf("Failed to start health check server: %v", err) } // !-- OS signal handling --! // // listening OS shutdown signal <-signalChan - + infoLogger.Printf("waiting for the server to shutdown gracefully...") + // cancel the context cancel() - infoLogger.Printf("Got OS shutdown signal, shutting down webhook server gracefully...") - s.Shutdown(context.Background()) //nolint:errcheck + // wait all server for shutdown + wg.Wait() + // time.Sleep(2 * time.Second) + infoLogger.Printf("All servers are down: bye...") +} + +func StarWebhook(ctx context.Context, wg *sync.WaitGroup, tlsC *tls.Config) (err error) { + s := httpserver.New(httpserver.WithAddr(webhookPort), httpserver.WithTLSConfig(tlsC)) + s.Router.Post(webhookPathMutate, serveHandler) + return s.Start(ctx, wg) } diff --git a/go.mod b/go.mod index ff51739..0516322 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/Masterminds/semver/v3 v3.3.0 github.com/containers/image/v5 v5.30.1 github.com/crazy-max/diun/v4 v4.28.0 + github.com/go-chi/chi/v5 v5.1.0 github.com/gookit/event v1.1.2 github.com/onsi/ginkgo/v2 v2.20.2 github.com/onsi/gomega v1.34.2 diff --git a/go.sum b/go.sum index 30d3415..6990074 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= diff --git a/internal/health/health.go b/internal/health/health.go index 51a57d8..17b5a23 100644 --- a/internal/health/health.go +++ b/internal/health/health.go @@ -2,12 +2,13 @@ package health import ( "context" - "errors" "flag" - "log" "net" "net/http" + "sync" "time" + + "github.com/orange-cloudavenue/kube-image-updater/internal/httpserver" ) const ( @@ -24,12 +25,9 @@ func init() { flag.StringVar(&healthPath, "health-path", healthPath, "Health server path. ex: /healthz") } -// ServeHealth starts the health check server -func ServeHealth(ctx context.Context) (err error) { - // Define Health check server - mux := http.NewServeMux() - mux.HandleFunc(healthPath, func(w http.ResponseWriter, r *http.Request) { - // TODO - Add more health checks like use of kube client on kube api server +// healthHandler returns a http.Handler that returns a health check response +func healthHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := net.DialTimeout("tcp", healthPort, timeoutR) if err != nil { return @@ -41,35 +39,11 @@ func ServeHealth(ctx context.Context) (err error) { return } }) +} - // create health check server - s := &http.Server{ - Addr: healthPort, - Handler: mux, - ReadTimeout: 10 * timeoutR, - } - - // start the HTTP server - go func() { - log.Printf("Starting health check server on %s", s.Addr) - if err = s.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - return - } - }() - - // kill the server if there is an error - go func() { - for { - <-ctx.Done() - ctxTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) - log.Printf("Shutting down health check server on %s", s.Addr) - defer cancel() - if err = s.Shutdown(ctxTimeout); err != nil { - log.Printf("Failed to shutdown health check server: %v", err) - } - return - } - }() - - return nil +// ServeHealth starts the health check server +func StartHealth(ctx context.Context, wg *sync.WaitGroup) (err error) { + s := httpserver.New(httpserver.WithAddr(healthPort)) + s.AddGetRoutes(healthPath, healthHandler()) + return s.Start(ctx, wg) } diff --git a/internal/httpserver/httpserver.go b/internal/httpserver/httpserver.go new file mode 100644 index 0000000..007893c --- /dev/null +++ b/internal/httpserver/httpserver.go @@ -0,0 +1,131 @@ +package httpserver + +import ( + "context" + "crypto/tls" + "errors" + "net/http" + "sync" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + log "github.com/sirupsen/logrus" +) + +const ( + timeout = 10 * time.Second + defaultPort = ":8080" +) + +type ( + HTTPServer struct { + Router *chi.Mux + Config *http.Server + } + Option func(s *http.Server) +) + +// NewHTTPServer returns a new HTTP router +// func New(path, port string, tlsC *tls.Config) (s HTTPServer) { +func New(opts ...Option) *HTTPServer { + s := &HTTPServer{} + s.Router = chi.NewRouter() + s.Router.Use(middleware.Logger) + s.Router.Use(middleware.Timeout(timeout)) + + // Default server configuration + s.Config = &http.Server{ + Addr: defaultPort, + Handler: s.Router, + ReadTimeout: timeout, + } + for _, opt := range opts { + opt(s.Config) + } + return s +} + +// WithTLSConfig sets the TLS configuration for the HTTP server +// Add an option to set the TLS configuration for the HTTP server +// The WithTLSConfig function takes a *tls.Config as an argument and returns an Option +// The Option type is a function that takes a *http.Server as an argument +// +// ex: New(httpserver.WithTLSConfig(tlsC)) +// ex: New(httpserver.WithTLSConfig(tlsC), httpserver.WithAddr(":8443")) +func WithTLSConfig(tlsC *tls.Config) Option { + return func(s *http.Server) { + s.TLSConfig = tlsC + } +} + +// WithAddr sets the address for the HTTP server +// Add an option to set the address for the HTTP server +// The WithAddr function takes a string as an argument and returns an Option +// The Option type is a function that takes a *http.Server as an argument +// +// ex: New(httpserver.WithAddr(":8443")) +// ex: New(httpserver.WithTLSConfig(tlsC), httpserver.WithAddr(":8443")) +func WithAddr(addr string) Option { + return func(s *http.Server) { + s.Addr = addr + } +} + +// Add Get routes to the HTTP server +func (s HTTPServer) AddGetRoutes(path string, handler http.Handler) { + s.Router.Mount(path, handler) +} + +// Add Post routes to the HTTP server +func (s HTTPServer) AddPostRoutes(path string, handler http.Handler) { + s.Router.Mount(path, handler) +} + +// ServeHTTP implements the http.Handler interface +func (s HTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.Router.ServeHTTP(w, r) +} + +// ListenAndServe starts the HTTP server +func (s HTTPServer) Start(ctx context.Context, wg *sync.WaitGroup) (err error) { + wg.Add(1) + defer wg.Done() + + switch s.Config.TLSConfig { + case nil: + // Start the HTTP server + go func() { + log.Infof("Starting server on %s", s.Config.Addr) + // log.Printf("Starting server on %s", s.Config.Addr) + if err = s.Config.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + return + } + }() + + default: + // Start the HTTPS server + go func() { + log.Infof("Starting TLS server on %s", s.Config.Addr) + if err = s.Config.ListenAndServeTLS("", ""); !errors.Is(err, http.ErrServerClosed) { + return + } + }() + } + + // Kill the server if there is an error or stop signal + go func() { + for { + <-ctx.Done() + defer wg.Done() + ctxTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) + log.Infof("Shutting down server on %s", s.Config.Addr) + cancel() + if err = s.Config.Shutdown(ctxTimeout); err != nil { + log.Errorf("Failed to shutdown HTTP server on %s: %v", s.Config.Addr, err) + } + return + } + }() + return nil +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index bd716a1..9e0edcb 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -2,15 +2,14 @@ package metrics import ( "context" - "errors" "flag" - "log" - "net/http" - "time" + "sync" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" + + "github.com/orange-cloudavenue/kube-image-updater/internal/httpserver" ) var ( @@ -78,38 +77,8 @@ func init() { // ServeProm starts a Prometheus metrics server // TODO - Add context to cancel the server // in order to stop the server gracefully -func ServeProm(ctx context.Context) (err error) { - // Define Metrics server - mux := http.NewServeMux() - mux.Handle(metricsPath, promhttp.Handler()) - - sm := &http.Server{ - Addr: metricsPort, - Handler: mux, - ReadTimeout: 10 * time.Second, - } - - // Start the metrics server - go func() { - log.Printf("Starting metrics server on %s", metricsPort) - if err = sm.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - return - } - }() - - // Kill the server if there is an error - go func() { - for { - <-ctx.Done() - ctxTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) - log.Printf("Shutting down metrics server on %s", sm.Addr) - defer cancel() - if err = sm.Shutdown(ctxTimeout); err != nil { - log.Printf("Failed to shutdown metrics server: %v", err) - } - return - } - }() - - return nil +func StartProm(ctx context.Context, wg *sync.WaitGroup) (err error) { + s := httpserver.New(httpserver.WithAddr(metricsPort)) + s.AddGetRoutes(metricsPath, promhttp.Handler()) + return s.Start(ctx, wg) }