diff --git a/README.md b/README.md index 9a69e78a..533c66d9 100644 --- a/README.md +++ b/README.md @@ -167,7 +167,7 @@ spec: source: repoURL: https://unikorn-cloud.github.io/unikorn chart: unikorn - targetRevision: v0.1.0 + targetRevision: v0.1.1 destination: namespace: unikorn server: https://kubernetes.default.svc diff --git a/charts/unikorn/Chart.yaml b/charts/unikorn/Chart.yaml index 2db0ac7f..e79d05ba 100644 --- a/charts/unikorn/Chart.yaml +++ b/charts/unikorn/Chart.yaml @@ -4,7 +4,7 @@ description: A Helm chart for deploying Unikorn type: application -version: v0.1.0 -appVersion: v0.1.0 +version: v0.1.1 +appVersion: v0.1.1 icon: https://raw.githubusercontent.com/unikorn-cloud/unikorn/main/icons/default.png diff --git a/charts/unikorn/templates/unikorn-server.yaml b/charts/unikorn/templates/unikorn-server.yaml index 58e1585e..5e3c316d 100644 --- a/charts/unikorn/templates/unikorn-server.yaml +++ b/charts/unikorn/templates/unikorn-server.yaml @@ -168,13 +168,16 @@ metadata: labels: {{- include "unikorn.labels" . | nindent 4 }} annotations: - {{- if .Values.server.ingress.annotations }} - {{ toYaml .Values.server.ingress.annotations | indent 2 }} + {{- if .Values.server.ingress.issuer }} + cert-manager.io/issuer: {{ .Values.server.ingress.issuer }} {{- else }} - cert-manager.io/issuer: "unikorn-server-ingress" + cert-manager.io/issuer: unikorn-server-ingress + {{- end }} + {{- if .Values.server.ingress.externalDns }} + external-dns.alpha.kubernetes.io/hostname: {{ .Values.server.ingress.host }} {{- end }} spec: - ingressClassName: {{ .Values.server.ingress.ingressClass }} + ingressClassName: {{ .Values.server.ingress.class }} # For development you will want to add these names to /etc/hosts for the ingress # endpoint address. tls: @@ -188,20 +191,11 @@ spec: - host: {{ .Values.server.ingress.host }} http: paths: - - path: /api - pathType: Prefix - backend: - service: - name: unikorn-server - port: - name: http -{{- if .Values.ui.enabled }} - path: / pathType: Prefix backend: service: - name: unikorn-ui + name: unikorn-server port: name: http {{- end }} -{{- end }} diff --git a/charts/unikorn/values.yaml b/charts/unikorn/values.yaml index 7e412e8e..78e0adb6 100644 --- a/charts/unikorn/values.yaml +++ b/charts/unikorn/values.yaml @@ -99,20 +99,18 @@ server: ingress: # Sets the ingress class to use. - ingressClass: nginx - - # A map of explicit annotations to add to the ingress. By default, when not - # specified, the chart will create an issuer and add in an annotation to generate - # self signed TLS secret with cert-manager. For real life deployments, you will - # want something like the following e.g. a shared cluster issuer, and external-dns - # to define the DNS address via DDNS and keep the IP address in sync. - # annotations: - # external-dns.alpha.kubernetes.io/hostname=unikorn.unikorn-cloud.org - # cert-manager.io/issuer: letsencrypt-prod + class: nginx # Sets the DNS hosts/X.509 Certs. host: unikorn.unikorn-cloud.org + # Cert Manager certificate issuer to use. If not specified it will generate a + # self signed one. + issuer: ~ + + # If true, will add the external DNS hostname annotation. + externalDns: false + oidc: # OIDC issuer used to discover OIDC configuration and verify access tokens. issuer: https://identity.unikorn-cloud.org @@ -124,11 +122,6 @@ server: # Sets the OTLP endpoint for shipping spans. # otlpEndpoint: jaeger-collector.default:4318 -# UI that works with the server. -ui: - # Temporarily block deployment until it's complete. - enabled: false - # Defines Prometheus monitoring integration. monitoring: # Enable monitoring, ensure Prometheus is installed first to define the CRDs. diff --git a/pkg/server/middleware/cors/cors.go b/pkg/server/middleware/cors/cors.go new file mode 100644 index 00000000..49f06c18 --- /dev/null +++ b/pkg/server/middleware/cors/cors.go @@ -0,0 +1,88 @@ +/* +Copyright 2024 the Unikorn Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cors + +import ( + "net/http" + "strconv" + "strings" + + "github.com/spf13/pflag" + + "github.com/unikorn-cloud/core/pkg/util" + "github.com/unikorn-cloud/unikorn/pkg/server/errors" + "github.com/unikorn-cloud/unikorn/pkg/server/middleware/openapi" +) + +type Options struct { + AllowedOrigins []string + MaxAge int +} + +func (o *Options) AddFlags(f *pflag.FlagSet) { + f.StringSliceVar(&o.AllowedOrigins, "--cors-allow-origin", []string{"*"}, "CORS allowed origins") + f.IntVar(&o.MaxAge, "--cors-max-age", 86400, "CORS maximum age (may be overridden by the browser)") +} + +func Middleware(schema *openapi.Schema, options *Options) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // All requests get the allow origin header. + for _, origin := range options.AllowedOrigins { + w.Header().Add("Access-Control-Allow-Origin", origin) + } + + // For normal requests handle them. + if r.Method != http.MethodOptions { + next.ServeHTTP(w, r) + return + } + + // Handle preflight + method := r.Header.Get("Access-Control-Request-Method") + if method == "" { + errors.HandleError(w, r, errors.OAuth2InvalidRequest("OPTIONS missing Access-Control-Request-Method header")) + return + } + + request := r.Clone(r.Context()) + request.Method = method + + route, _, err := schema.FindRoute(request) + if err != nil { + errors.HandleError(w, r, err) + return + } + + // TODO: add OPTIONS to the schema? + methods := util.Keys(route.PathItem.Operations()) + methods = append(methods, http.MethodOptions) + + // TODO: get these from the schema. + headers := []string{ + "Authorization", + "traceparent", + "tracestate", + } + + w.Header().Add("Access-Control-Allow-Methods", strings.Join(methods, ", ")) + w.Header().Add("Access-Control-Allow-Headers", strings.Join(headers, ", ")) + w.Header().Add("Access-Control-Max-Age", strconv.Itoa(options.MaxAge)) + w.WriteHeader(http.StatusNoContent) + }) + } +} diff --git a/pkg/server/middleware/authorization.go b/pkg/server/middleware/openapi/authorization.go similarity index 95% rename from pkg/server/middleware/authorization.go rename to pkg/server/middleware/openapi/authorization.go index 681aff8d..f3caac42 100644 --- a/pkg/server/middleware/authorization.go +++ b/pkg/server/middleware/openapi/authorization.go @@ -15,7 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package middleware +package openapi import ( "crypto/tls" @@ -44,7 +44,7 @@ type authorizationContext struct { claims oauth2.Claims } -type AuthorizerOptions struct { +type Options struct { // issuer is used to perform OIDC discovery and verify access tokens // using the JWKS endpoint. issuer string @@ -53,18 +53,18 @@ type AuthorizerOptions struct { issuerCA []byte } -func (o *AuthorizerOptions) AddFlags(f *pflag.FlagSet) { +func (o *Options) AddFlags(f *pflag.FlagSet) { f.StringVar(&o.issuer, "oidc-issuer", "", "OIDC issuer URL to use for token validation.") f.BytesBase64Var(&o.issuerCA, "oidc-issuer-ca", nil, "base64 OIDC endpoint CA certificate.") } // Authorizer provides OpenAPI based authorization middleware. type Authorizer struct { - options *AuthorizerOptions + options *Options } // NewAuthorizer returns a new authorizer with required parameters. -func NewAuthorizer(options *AuthorizerOptions) *Authorizer { +func NewAuthorizer(options *Options) *Authorizer { return &Authorizer{ options: options, } diff --git a/pkg/server/middleware/openapi.go b/pkg/server/middleware/openapi/openapi.go similarity index 82% rename from pkg/server/middleware/openapi.go rename to pkg/server/middleware/openapi/openapi.go index 5292310e..48d48450 100644 --- a/pkg/server/middleware/openapi.go +++ b/pkg/server/middleware/openapi/openapi.go @@ -15,7 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package middleware +package openapi import ( "bytes" @@ -34,26 +34,26 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" ) -// OpenAPIValidator provides OpenAPI validation of request and response codes, +// Validator provides Schema validation of request and response codes, // media, and schema validation of payloads to ensure we are meeting the // specification. -type OpenAPIValidator struct { +type Validator struct { // next defines the next HTTP handler in the chain. next http.Handler // authorizer provides security policy enforcement. authorizer *Authorizer - // openapi caches the OpenAPI schema. - openapi *OpenAPI + // openapi caches the Schema schema. + openapi *Schema } // Ensure this implements the required interfaces. -var _ http.Handler = &OpenAPIValidator{} +var _ http.Handler = &Validator{} -// NewOpenAPIValidator returns an initialized validator middleware. -func NewOpenAPIValidator(authorizer *Authorizer, next http.Handler, openapi *OpenAPI) *OpenAPIValidator { - return &OpenAPIValidator{ +// NewValidator returns an initialized validator middleware. +func NewValidator(authorizer *Authorizer, next http.Handler, openapi *Schema) *Validator { + return &Validator{ authorizer: authorizer, next: next, openapi: openapi, @@ -109,13 +109,13 @@ func (w *bufferingResponseWriter) StatusCode() int { return w.code } -func (v *OpenAPIValidator) validateRequest(r *http.Request, authContext *authorizationContext) (*openapi3filter.ResponseValidationInput, error) { +func (v *Validator) validateRequest(r *http.Request, authContext *authorizationContext) (*openapi3filter.ResponseValidationInput, error) { tracer := otel.GetTracerProvider().Tracer(constants.Application) ctx, span := tracer.Start(r.Context(), "openapi request validation", trace.WithSpanKind(trace.SpanKindInternal)) defer span.End() - route, params, err := v.openapi.findRoute(r) + route, params, err := v.openapi.FindRoute(r) if err != nil { return nil, errors.OAuth2ServerError("route lookup failure").WithError(err) } @@ -159,7 +159,7 @@ func (v *OpenAPIValidator) validateRequest(r *http.Request, authContext *authori return responseValidationInput, nil } -func (v *OpenAPIValidator) validateResponse(w *bufferingResponseWriter, r *http.Request, responseValidationInput *openapi3filter.ResponseValidationInput) { +func (v *Validator) validateResponse(w *bufferingResponseWriter, r *http.Request, responseValidationInput *openapi3filter.ResponseValidationInput) { tracer := otel.GetTracerProvider().Tracer(constants.Application) ctx, span := tracer.Start(r.Context(), "openapi response validation", @@ -177,7 +177,7 @@ func (v *OpenAPIValidator) validateResponse(w *bufferingResponseWriter, r *http. } // ServeHTTP implements the http.Handler interface. -func (v *OpenAPIValidator) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (v *Validator) ServeHTTP(w http.ResponseWriter, r *http.Request) { authContext := &authorizationContext{} responseValidationInput, err := v.validateRequest(r, authContext) @@ -200,10 +200,10 @@ func (v *OpenAPIValidator) ServeHTTP(w http.ResponseWriter, r *http.Request) { v.validateResponse(writer, r, responseValidationInput) } -// OpenAPIValidatorMiddlewareFactory returns a function that generates per-request +// Middleware returns a function that generates per-request // middleware functions. -func OpenAPIValidatorMiddlewareFactory(authorizer *Authorizer, openapi *OpenAPI) func(http.Handler) http.Handler { +func Middleware(authorizer *Authorizer, openapi *Schema) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - return NewOpenAPIValidator(authorizer, next, openapi) + return NewValidator(authorizer, next, openapi) } } diff --git a/pkg/server/middleware/util.go b/pkg/server/middleware/openapi/schema.go similarity index 80% rename from pkg/server/middleware/util.go rename to pkg/server/middleware/openapi/schema.go index 4f60120b..f348f120 100644 --- a/pkg/server/middleware/util.go +++ b/pkg/server/middleware/openapi/schema.go @@ -15,7 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package middleware +package openapi import ( "net/http" @@ -28,8 +28,8 @@ import ( "github.com/unikorn-cloud/unikorn/pkg/server/generated" ) -// OpenAPI abstracts schema access and validation. -type OpenAPI struct { +// Schema abstracts schema access and validation. +type Schema struct { // spec is the full specification. spec *openapi3.T @@ -40,7 +40,7 @@ type OpenAPI struct { // NewOpenRpi extracts the swagger document. // NOTE: this is surprisingly slow, make sure you cache it and reuse it. -func NewOpenAPI() (*OpenAPI, error) { +func NewSchema() (*Schema, error) { spec, err := generated.GetSwagger() if err != nil { return nil, err @@ -51,17 +51,17 @@ func NewOpenAPI() (*OpenAPI, error) { return nil, err } - o := &OpenAPI{ + s := &Schema{ spec: spec, router: router, } - return o, nil + return s, nil } -// findRoute looks up the route from the specification. -func (o *OpenAPI) findRoute(r *http.Request) (*routers.Route, map[string]string, error) { - route, params, err := o.router.FindRoute(r) +// FindRoute looks up the route from the specification. +func (s *Schema) FindRoute(r *http.Request) (*routers.Route, map[string]string, error) { + route, params, err := s.router.FindRoute(r) if err != nil { return nil, nil, errors.OAuth2ServerError("unable to find route").WithError(err) } diff --git a/pkg/server/middleware/opentelemetry.go b/pkg/server/middleware/opentelemetry/opentelemetry.go similarity index 98% rename from pkg/server/middleware/opentelemetry.go rename to pkg/server/middleware/opentelemetry/opentelemetry.go index a8f4f725..56076684 100644 --- a/pkg/server/middleware/opentelemetry.go +++ b/pkg/server/middleware/opentelemetry/opentelemetry.go @@ -15,7 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package middleware +package opentelemetry import ( "context" @@ -256,8 +256,8 @@ func httpStatusToOtelCode(status int) (codes.Code, string) { return code, http.StatusText(status) } -// Logger attaches logging context to the request. -func Logger() func(next http.Handler) http.Handler { +// Middleware attaches logging context to the request. +func Middleware() func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Extract the tracing information from the HTTP headers. diff --git a/pkg/server/middleware/timeout.go b/pkg/server/middleware/timeout/timeout.go similarity index 87% rename from pkg/server/middleware/timeout.go rename to pkg/server/middleware/timeout/timeout.go index 8db68ed5..152fad0f 100644 --- a/pkg/server/middleware/timeout.go +++ b/pkg/server/middleware/timeout/timeout.go @@ -15,7 +15,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -package middleware +package timeout import ( "context" @@ -23,8 +23,8 @@ import ( "time" ) -// Timeout adds a timeout to requests. -func Timeout(timeout time.Duration) func(http.Handler) http.Handler { +// Middleware adds a timeout to requests. +func Middleware(timeout time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), timeout) diff --git a/pkg/server/server.go b/pkg/server/server.go index cf945cce..0c8f126b 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -31,7 +31,10 @@ import ( "github.com/unikorn-cloud/unikorn/pkg/server/generated" "github.com/unikorn-cloud/unikorn/pkg/server/handler" - "github.com/unikorn-cloud/unikorn/pkg/server/middleware" + "github.com/unikorn-cloud/unikorn/pkg/server/middleware/cors" + "github.com/unikorn-cloud/unikorn/pkg/server/middleware/openapi" + "github.com/unikorn-cloud/unikorn/pkg/server/middleware/opentelemetry" + "github.com/unikorn-cloud/unikorn/pkg/server/middleware/timeout" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" @@ -48,7 +51,11 @@ type Server struct { // HandlerOptions sets options for the HTTP handler. HandlerOptions handler.Options - AuthorizerOptions middleware.AuthorizerOptions + // AuthorizerOptions allow configuration of the OIDC backend. + AuthorizerOptions openapi.Options + + // CORSOptions are for remote resource sharing. + CORSOptions cors.Options } func (s *Server) AddFlags(goflags *flag.FlagSet, flags *pflag.FlagSet) { @@ -57,6 +64,7 @@ func (s *Server) AddFlags(goflags *flag.FlagSet, flags *pflag.FlagSet) { s.Options.AddFlags(flags) s.HandlerOptions.AddFlags(flags) s.AuthorizerOptions.AddFlags(flags) + s.CORSOptions.AddFlags(flags) } func (s *Server) SetupLogging() { @@ -72,7 +80,7 @@ func (s *Server) SetupOpenTelemetry(ctx context.Context) error { otel.SetTextMapPropagator(propagation.TraceContext{}) opts := []trace.TracerProviderOption{ - trace.WithSpanProcessor(&middleware.LoggingSpanProcessor{}), + trace.WithSpanProcessor(&opentelemetry.LoggingSpanProcessor{}), } if s.Options.OTLPEndpoint != "" { @@ -94,20 +102,21 @@ func (s *Server) SetupOpenTelemetry(ctx context.Context) error { } func (s *Server) GetServer(client client.Client) (*http.Server, error) { + schema, err := openapi.NewSchema() + if err != nil { + return nil, err + } + // Middleware specified here is applied to all requests pre-routing. router := chi.NewRouter() - router.Use(middleware.Logger()) - router.Use(middleware.Timeout(s.Options.RequestTimeout)) + router.Use(timeout.Middleware(s.Options.RequestTimeout)) + router.Use(opentelemetry.Middleware()) + router.Use(cors.Middleware(schema, &s.CORSOptions)) router.NotFound(http.HandlerFunc(handler.NotFound)) router.MethodNotAllowed(http.HandlerFunc(handler.MethodNotAllowed)) // Setup middleware. - authorizer := middleware.NewAuthorizer(&s.AuthorizerOptions) - - openapi, err := middleware.NewOpenAPI() - if err != nil { - return nil, err - } + authorizer := openapi.NewAuthorizer(&s.AuthorizerOptions) // Middleware specified here is applied to all requests post-routing. // NOTE: these are applied in reverse order!! @@ -115,7 +124,7 @@ func (s *Server) GetServer(client client.Client) (*http.Server, error) { BaseRouter: router, ErrorHandlerFunc: handler.HandleError, Middlewares: []generated.MiddlewareFunc{ - middleware.OpenAPIValidatorMiddlewareFactory(authorizer, openapi), + openapi.Middleware(authorizer, schema), }, }