diff --git a/internal/amf/amf.go b/internal/amf/amf.go index 7dbb47a..b177cdd 100644 --- a/internal/amf/amf.go +++ b/internal/amf/amf.go @@ -27,6 +27,7 @@ type Amf struct { userAgent string smf *smf.Smf srv *http.Server + closed chan struct{} } func NewAmf(bindAddr netip.AddrPort, control jsonapi.ControlURI, userAgent string, smf *smf.Smf) *Amf { @@ -35,6 +36,7 @@ func NewAmf(bindAddr netip.AddrPort, control jsonapi.ControlURI, userAgent strin client: http.Client{}, userAgent: userAgent, smf: smf, + closed: make(chan struct{}), } // TODO: gin.SetMode(gin.DebugMode) / gin.SetMode(gin.ReleaseMode) depending on log level r := gin.Default() @@ -65,19 +67,28 @@ func (amf *Amf) Start(ctx context.Context) error { } }(l) go func(ctx context.Context) { + defer close(amf.closed) select { case <-ctx.Done(): - ctxShutdown, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctxShutdown, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() if err := amf.srv.Shutdown(ctxShutdown); err != nil { logrus.WithError(err).Info("HTTP Server Shutdown") } } }(ctx) - return nil } +func (amf *Amf) WaitShutdown(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-amf.closed: + return nil + } +} + // get status of the controller func Status(c *gin.Context) { status := healthcheck.Status{ diff --git a/internal/app/setup.go b/internal/app/setup.go index 8c3953a..e420acc 100644 --- a/internal/app/setup.go +++ b/internal/app/setup.go @@ -7,6 +7,7 @@ package app import ( "context" + "time" "github.com/nextmn/cp-lite/internal/amf" "github.com/nextmn/cp-lite/internal/config" @@ -27,27 +28,20 @@ func NewSetup(config *config.CPConfig) *Setup { smf: smf, } } -func (s *Setup) Init(ctx context.Context) error { + +func (s *Setup) Run(ctx context.Context) error { if err := s.smf.Start(ctx); err != nil { return err } if err := s.amf.Start(ctx); err != nil { return err } - return nil -} - -func (s *Setup) Run(ctx context.Context) error { - defer s.Exit() - if err := s.Init(ctx); err != nil { - return err - } select { case <-ctx.Done(): - return nil + ctxShutdown, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + s.amf.WaitShutdown(ctxShutdown) + s.smf.WaitShutdown(ctxShutdown) } -} - -func (s *Setup) Exit() error { return nil } diff --git a/internal/smf/smf.go b/internal/smf/smf.go index a734b2f..315b76f 100644 --- a/internal/smf/smf.go +++ b/internal/smf/smf.go @@ -26,6 +26,7 @@ type Smf struct { slices *SlicesMap srv *pfcp.PFCPEntityCP started bool + closed chan struct{} } func NewSmf(addr netip.Addr, slices map[string]config.Slice) *Smf { @@ -35,12 +36,14 @@ func NewSmf(addr netip.Addr, slices map[string]config.Slice) *Smf { srv: pfcp.NewPFCPEntityCP(addr.String(), addr), slices: s, upfs: upfs, + closed: make(chan struct{}), } } func (smf *Smf) Start(ctx context.Context) error { logrus.Info("Starting PFCP Server") go func() { + defer close(smf.closed) err := smf.srv.ListenAndServeContext(ctx) if err != nil { logrus.WithError(err).Trace("PFCP server stopped") @@ -204,3 +207,11 @@ func (smf *Smf) CreateSessionUplink(ctx context.Context, ueCtrl jsonapi.ControlU slice.sessions.Store(ueCtrl, &session) return &session, nil } +func (smf *Smf) WaitShutdown(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-smf.closed: + return nil + } +}