From d284ce40d79cee54d9f003d6d46a74617d9bd65e Mon Sep 17 00:00:00 2001 From: Leigh MacDonald Date: Fri, 21 Jun 2024 04:30:23 -0600 Subject: [PATCH] Add servers tests --- frontend/src/api/server.ts | 2 +- internal/auth/auth_usecase.go | 2 +- internal/demo/demo_usecase.go | 2 +- internal/discord/discord_service.go | 6 +- internal/domain/server.go | 39 +++++- internal/httphelper/http.go | 4 + internal/match/match_repository.go | 2 +- internal/match/match_service.go | 2 +- internal/match/match_usecase.go | 2 +- internal/network/scp.go | 2 +- internal/servers/servers_repository.go | 3 - internal/servers/servers_service.go | 161 ++++--------------------- internal/servers/servers_usecase.go | 79 ++++++++---- internal/srcds/srcds_service.go | 2 +- internal/state/collector.go | 2 +- internal/state/state_usecase.go | 2 +- internal/test/main_test.go | 5 +- internal/test/servers_test.go | 134 ++++++++++++++++++++ 18 files changed, 267 insertions(+), 184 deletions(-) create mode 100644 internal/test/servers_test.go diff --git a/frontend/src/api/server.ts b/frontend/src/api/server.ts index a1711996..7f636f2d 100644 --- a/frontend/src/api/server.ts +++ b/frontend/src/api/server.ts @@ -106,7 +106,7 @@ export const apiSaveServer = async (server_id: number, opts: SaveServerOpts) => }; export const apiGetServersAdmin = async (abortController?: AbortController) => { - const resp = await apiCall(`/api/servers_admin`, 'POST', undefined, abortController); + const resp = await apiCall(`/api/servers_admin`, 'GET', undefined, abortController); return resp.map(transformTimeStampedDates).map((s) => { s.token_created_on = parseDateTime(s.token_created_on as unknown as string); return s; diff --git a/internal/auth/auth_usecase.go b/internal/auth/auth_usecase.go index a9a4b5ef..511e6a85 100644 --- a/internal/auth/auth_usecase.go +++ b/internal/auth/auth_usecase.go @@ -213,7 +213,7 @@ func (u *auth) AuthServerMiddleWare() gin.HandlerFunc { } var server domain.Server - if errServer := u.servers.GetServerByPassword(ctx, reqAuthHeader, &server, false, false); errServer != nil { + if errServer := u.servers.GetByPassword(ctx, reqAuthHeader, &server, false, false); errServer != nil { slog.Error("Failed to load server during auth", log.ErrAttr(errServer), slog.String("token", reqAuthHeader), slog.String("IP", ctx.ClientIP())) ctx.AbortWithStatus(http.StatusUnauthorized) diff --git a/internal/demo/demo_usecase.go b/internal/demo/demo_usecase.go index aaf6f515..e7f15c56 100644 --- a/internal/demo/demo_usecase.go +++ b/internal/demo/demo_usecase.go @@ -210,7 +210,7 @@ func (d demoUsecase) GetDemos(ctx context.Context) ([]domain.DemoFile, error) { } func (d demoUsecase) CreateFromAsset(ctx context.Context, asset domain.Asset, serverID int) (*domain.DemoFile, error) { - _, errGetServer := d.servers.GetServer(ctx, serverID) + _, errGetServer := d.servers.Server(ctx, serverID) if errGetServer != nil { return nil, domain.ErrGetServer } diff --git a/internal/discord/discord_service.go b/internal/discord/discord_service.go index c625a0a1..c6f7e659 100644 --- a/internal/discord/discord_service.go +++ b/internal/discord/discord_service.go @@ -366,7 +366,7 @@ func (h discordService) makeOnSay() func(context.Context, *discordgo.Session, *d msg := opts[domain.OptMessage].StringValue() var server domain.Server - if err := h.servers.GetServerByName(ctx, serverName, &server, false, false); err != nil { + if err := h.servers.GetByName(ctx, serverName, &server, false, false); err != nil { return nil, domain.ErrUnknownServer } @@ -386,7 +386,7 @@ func (h discordService) makeOnCSay() func(_ context.Context, _ *discordgo.Sessio msg := opts[domain.OptMessage].StringValue() var server domain.Server - if err := h.servers.GetServerByName(ctx, serverName, &server, false, false); err != nil { + if err := h.servers.GetByName(ctx, serverName, &server, false, false); err != nil { return nil, domain.ErrUnknownServer } @@ -665,7 +665,7 @@ func (h discordService) makeOnFind() func(context.Context, *discordgo.Session, * var found []domain.FoundPlayer for _, player := range players { - server, errServer := h.servers.GetServer(ctx, player.ServerID) + server, errServer := h.servers.Server(ctx, player.ServerID) if errServer != nil { return nil, errors.Join(errServer, domain.ErrGetServer) } diff --git a/internal/domain/server.go b/internal/domain/server.go index 94b5e5a5..b16af31c 100644 --- a/internal/domain/server.go +++ b/internal/domain/server.go @@ -12,13 +12,40 @@ import ( "github.com/leighmacdonald/steamid/v4/steamid" ) +type RequestServerUpdate struct { + ServerID int `json:"server_id"` + ServerName string `json:"server_name"` + ServerNameShort string `json:"server_name_short"` + Host string `json:"host"` + Port int `json:"port"` + ReservedSlots int `json:"reserved_slots"` + Password string `json:"password"` + RCON string `json:"rcon"` + Lat float64 `json:"lat"` + Lon float64 `json:"lon"` + CC string `json:"cc"` + DefaultMap string `json:"default_map"` + Region string `json:"region"` + IsEnabled bool `json:"is_enabled"` + EnableStats bool `json:"enable_stats"` + LogSecret int `json:"log_secret"` +} + +type ServerInfoSafe struct { + ServerNameLong string `json:"server_name_long"` + ServerName string `json:"server_name"` + ServerID int `json:"server_id"` + Colour string `json:"colour"` +} + type ServersUsecase interface { - GetServer(ctx context.Context, serverID int) (Server, error) - GetServerPermissions(ctx context.Context) ([]ServerPermission, error) - GetServers(ctx context.Context, filter ServerQueryFilter) ([]Server, int64, error) - GetServerByName(ctx context.Context, serverName string, server *Server, disabledOk bool, deletedOk bool) error - GetServerByPassword(ctx context.Context, serverPassword string, server *Server, disabledOk bool, deletedOk bool) error - SaveServer(ctx context.Context, server *Server) error + Server(ctx context.Context, serverID int) (Server, error) + ServerPermissions(ctx context.Context) ([]ServerPermission, error) + Servers(ctx context.Context, filter ServerQueryFilter) ([]Server, int64, error) + GetByName(ctx context.Context, serverName string, server *Server, disabledOk bool, deletedOk bool) error + GetByPassword(ctx context.Context, serverPassword string, server *Server, disabledOk bool, deletedOk bool) error + Save(ctx context.Context, req RequestServerUpdate) (Server, error) + Delete(ctx context.Context, serverID int) error } type ServersRepository interface { diff --git a/internal/httphelper/http.go b/internal/httphelper/http.go index 7a9380c4..1ce2c670 100644 --- a/internal/httphelper/http.go +++ b/internal/httphelper/http.go @@ -87,6 +87,10 @@ func HandleErrs(ctx *gin.Context, err error) { HandleErrPermissionDenied(ctx) case errors.Is(err, domain.ErrNoResult): HandleErrNotFound(ctx) + case errors.Is(err, domain.ErrParamKeyMissing): + HandleErrBadRequest(ctx) + case errors.Is(err, domain.ErrInvalidParameter): + HandleErrBadRequest(ctx) case errors.Is(err, domain.ErrBadRequest): HandleErrBadRequest(ctx) case errors.Is(err, domain.ErrDuplicate): diff --git a/internal/match/match_repository.go b/internal/match/match_repository.go index 6659e15b..8ab10969 100644 --- a/internal/match/match_repository.go +++ b/internal/match/match_repository.go @@ -66,7 +66,7 @@ func (r *matchRepository) onMatchComplete(ctx context.Context, matchContext *act matchContext.match.Title = server.Name } - fullServer, err := r.servers.GetServer(ctx, server.ServerID) + fullServer, err := r.servers.Server(ctx, server.ServerID) if err != nil { return errors.Join(err, domain.ErrLoadServer) } diff --git a/internal/match/match_service.go b/internal/match/match_service.go index b6440a38..e582b7b6 100644 --- a/internal/match/match_service.go +++ b/internal/match/match_service.go @@ -96,7 +96,7 @@ func (h matchHandler) onAPIPostMatchStart() gin.HandlerFunc { return } - server, errServer := h.servers.GetServer(ctx, serverID) + server, errServer := h.servers.Server(ctx, serverID) if errServer != nil { httphelper.ResponseErr(ctx, http.StatusInternalServerError, domain.ErrUnknownServerID) slog.Error("Failed to get server", log.ErrAttr(errServer)) diff --git a/internal/match/match_usecase.go b/internal/match/match_usecase.go index ecdec1f1..8d67b5dc 100644 --- a/internal/match/match_usecase.go +++ b/internal/match/match_usecase.go @@ -54,7 +54,7 @@ func (m matchUsecase) EndMatch(ctx context.Context, serverID int) (uuid.UUID, er return matchID, domain.ErrLoadMatch } - server, errServer := m.servers.GetServer(ctx, serverID) + server, errServer := m.servers.Server(ctx, serverID) if errServer != nil { return matchID, errors.Join(errServer, domain.ErrUnknownServer) } diff --git a/internal/network/scp.go b/internal/network/scp.go index 278a8e74..a4be367e 100644 --- a/internal/network/scp.go +++ b/internal/network/scp.go @@ -73,7 +73,7 @@ func (f SCPExecer) Start(ctx context.Context) { } func (f SCPExecer) update(ctx context.Context) error { - servers, _, errServers := f.serversUsecase.GetServers(ctx, domain.ServerQueryFilter{}) + servers, _, errServers := f.serversUsecase.Servers(ctx, domain.ServerQueryFilter{}) if errServers != nil { return errServers } diff --git a/internal/servers/servers_repository.go b/internal/servers/servers_repository.go index 5b5148b8..59c08eac 100644 --- a/internal/servers/servers_repository.go +++ b/internal/servers/servers_repository.go @@ -252,13 +252,10 @@ func (r *serversRepository) GetServerByPassword(ctx context.Context, serverPassw // SaveServer updates or creates the server data in the database. func (r *serversRepository) SaveServer(ctx context.Context, server *domain.Server) error { - server.UpdatedOn = time.Now() if server.ServerID > 0 { return r.updateServer(ctx, server) } - server.CreatedOn = time.Now() - return r.insertServer(ctx, server) } diff --git a/internal/servers/servers_service.go b/internal/servers/servers_service.go index 6d934135..1f00b2d9 100644 --- a/internal/servers/servers_service.go +++ b/internal/servers/servers_service.go @@ -1,12 +1,10 @@ package servers import ( - "fmt" "log/slog" "math" "net/http" "sort" - "strings" "github.com/gin-gonic/gin" "github.com/leighmacdonald/gbans/internal/domain" @@ -30,7 +28,6 @@ func NewServerHandler(engine *gin.Engine, serversUsecase domain.ServersUsecase, persons: personUsecase, } - engine.GET("/export/sourcemod/admins_simple.ini", handler.onAPIExportSourcemodSimpleAdmins()) engine.GET("/api/servers/state", handler.onAPIGetServerStates()) engine.GET("/api/servers", handler.onAPIGetServers()) @@ -41,69 +38,13 @@ func NewServerHandler(engine *gin.Engine, serversUsecase domain.ServersUsecase, admin.POST("/api/servers", handler.onAPIPostServer()) admin.POST("/api/servers/:server_id", handler.onAPIPostServerUpdate()) admin.DELETE("/api/servers/:server_id", handler.onAPIPostServerDelete()) - admin.POST("/api/servers_admin", handler.onAPIGetServersAdmin()) - } -} - -type serverInfoSafe struct { - ServerNameLong string `json:"server_name_long"` - ServerName string `json:"server_name"` - ServerID int `json:"server_id"` - Colour string `json:"colour"` -} - -func (h *serversHandler) onAPIExportSourcemodSimpleAdmins() gin.HandlerFunc { - return func(ctx *gin.Context) { - privilegedIDs, errPrivilegedIDs := h.persons.GetSteamIDsAbove(ctx, domain.PReserved) - if errPrivilegedIDs != nil { - httphelper.HandleErrInternal(ctx) - slog.Error("Failed to get steam ids", log.ErrAttr(errPrivilegedIDs)) - - return - } - - players, errPlayers := h.persons.GetPeopleBySteamID(ctx, privilegedIDs) - if errPlayers != nil { - httphelper.HandleErrInternal(ctx) - slog.Error("Failed to get people by steamid", log.ErrAttr(errPlayers)) - - return - } - - sort.Slice(players, func(i, j int) bool { - return players[i].PermissionLevel > players[j].PermissionLevel - }) - - bld := strings.Builder{} - - for _, player := range players { - var perms string - - switch player.PermissionLevel { - case domain.PAdmin: - perms = "z" - case domain.PModerator: - perms = "abcdefgjk" - case domain.PEditor: - perms = "ak" - case domain.PReserved: - perms = "a" - } - - if perms == "" { - slog.Warn("User has no perm string", slog.Int64("sid", player.SteamID.Int64())) - } else { - bld.WriteString(fmt.Sprintf("\"%s\" \"%s\"\n", player.SteamID.Steam3(), perms)) - } - } - - ctx.String(http.StatusOK, bld.String()) + admin.GET("/api/servers_admin", handler.onAPIGetServersAdmin()) } } func (h *serversHandler) onAPIGetServers() gin.HandlerFunc { return func(ctx *gin.Context) { - fullServers, _, errServers := h.servers.GetServers(ctx, domain.ServerQueryFilter{}) + fullServers, _, errServers := h.servers.Servers(ctx, domain.ServerQueryFilter{}) if errServers != nil { httphelper.HandleErrInternal(ctx) slog.Error("Failed to get servers", log.ErrAttr(errServers)) @@ -111,9 +52,9 @@ func (h *serversHandler) onAPIGetServers() gin.HandlerFunc { return } - var servers []serverInfoSafe + var servers []domain.ServerInfoSafe for _, server := range fullServers { - servers = append(servers, serverInfoSafe{ + servers = append(servers, domain.ServerInfoSafe{ ServerNameLong: server.Name, ServerName: server.ShortName, ServerID: server.ServerID, @@ -200,52 +141,24 @@ func (h *serversHandler) onAPIGetServerStates() gin.HandlerFunc { func (h *serversHandler) onAPIPostServer() gin.HandlerFunc { return func(ctx *gin.Context) { - var req serverUpdateRequest + var req domain.RequestServerUpdate if !httphelper.Bind(ctx, &req) { return } - server := domain.NewServer(req.ServerNameShort, req.Host, req.Port) - server.Name = req.ServerName - server.Password = req.Password - server.ReservedSlots = req.ReservedSlots - server.RCON = req.RCON - server.Latitude = req.Lat - server.Longitude = req.Lon - server.CC = req.CC - server.Region = req.Region - server.IsEnabled = req.IsEnabled - server.LogSecret = req.LogSecret - - if errSave := h.servers.SaveServer(ctx, &server); errSave != nil { - httphelper.HandleErrInternal(ctx) + server, errSave := h.servers.Save(ctx, req) + if errSave != nil { + httphelper.HandleErrs(ctx, errSave) slog.Error("Failed to save new server", log.ErrAttr(errSave)) return } ctx.JSON(http.StatusOK, server) + slog.Info("Created new server", slog.String("name", server.ShortName), slog.Int("server_id", server.ServerID)) } } -type serverUpdateRequest struct { - ServerName string `json:"server_name"` - ServerNameShort string `json:"server_name_short"` - Host string `json:"host"` - Port int `json:"port"` - ReservedSlots int `json:"reserved_slots"` - Password string `json:"password"` - RCON string `json:"rcon"` - Lat float64 `json:"lat"` - Lon float64 `json:"lon"` - CC string `json:"cc"` - DefaultMap string `json:"default_map"` - Region string `json:"region"` - IsEnabled bool `json:"is_enabled"` - EnableStats bool `json:"enable_stats"` - LogSecret int `json:"log_secret"` -} - func (h *serversHandler) onAPIPostServerUpdate() gin.HandlerFunc { return func(ctx *gin.Context) { serverID, errServerID := httphelper.GetIntParam(ctx, "server_id") @@ -256,42 +169,23 @@ func (h *serversHandler) onAPIPostServerUpdate() gin.HandlerFunc { return } - server, errServer := h.servers.GetServer(ctx, serverID) - if errServer != nil { - httphelper.HandleErrInternal(ctx) - slog.Error("Failed to get server", log.ErrAttr(errServer)) - - return - } - - var req serverUpdateRequest + var req domain.RequestServerUpdate if !httphelper.Bind(ctx, &req) { return } - server.ShortName = req.ServerNameShort - server.Name = req.ServerName - server.Address = req.Host - server.Port = req.Port - server.ReservedSlots = req.ReservedSlots - server.RCON = req.RCON - server.Password = req.Password - server.Latitude = req.Lat - server.Longitude = req.Lon - server.CC = req.CC - server.Region = req.Region - server.IsEnabled = req.IsEnabled - server.LogSecret = req.LogSecret - server.EnableStats = req.EnableStats - - if errSave := h.servers.SaveServer(ctx, &server); errSave != nil { - httphelper.HandleErrInternal(ctx) + req.ServerID = serverID + + server, errSave := h.servers.Save(ctx, req) + if errSave != nil { + httphelper.HandleErrs(ctx, errServerID) slog.Error("Failed to update server", log.ErrAttr(errSave)) return } ctx.JSON(http.StatusOK, server) + slog.Info("Updated server successfully", slog.String("name", server.ShortName)) } } @@ -301,7 +195,7 @@ func (h *serversHandler) onAPIGetServersAdmin() gin.HandlerFunc { IncludeDisabled: true, } - servers, _, errServers := h.servers.GetServers(ctx, filter) + servers, _, errServers := h.servers.Servers(ctx, filter) if errServers != nil { httphelper.HandleErrInternal(ctx) slog.Error("Failed to get servers", log.ErrAttr(errServers)) @@ -321,29 +215,20 @@ func (h *serversHandler) onAPIPostServerDelete() gin.HandlerFunc { return func(ctx *gin.Context) { serverID, errID := httphelper.GetIntParam(ctx, "server_id") if errID != nil { - httphelper.HandleErrBadRequest(ctx) + httphelper.HandleErrs(ctx, errID) slog.Error("Failed to get server_id", log.ErrAttr(errID)) return } - server, errServer := h.servers.GetServer(ctx, serverID) - if errServer != nil { - httphelper.HandleErrInternal(ctx) - slog.Error("Failed to get server", log.ErrAttr(errServer)) + if err := h.servers.Delete(ctx, serverID); err != nil { + httphelper.HandleErrs(ctx, err) + slog.Error("Failed to delete server", log.ErrAttr(err)) return } - server.Deleted = true - - if errSave := h.servers.SaveServer(ctx, &server); errSave != nil { - httphelper.HandleErrInternal(ctx) - slog.Error("Failed to delete server", log.ErrAttr(errSave)) - - return - } - - ctx.JSON(http.StatusOK, server) + ctx.JSON(http.StatusOK, gin.H{}) + slog.Info("Deleted server", slog.Int("server_id", serverID)) } } diff --git a/internal/servers/servers_usecase.go b/internal/servers/servers_usecase.go index 77942854..1deabe8f 100644 --- a/internal/servers/servers_usecase.go +++ b/internal/servers/servers_usecase.go @@ -2,55 +2,90 @@ package servers import ( "context" - "log/slog" + "time" "github.com/leighmacdonald/gbans/internal/domain" ) type serversUsecase struct { - servers domain.ServersRepository + repository domain.ServersRepository +} + +func (s *serversUsecase) Delete(ctx context.Context, serverID int) error { + if serverID <= 0 { + return domain.ErrInvalidParameter + } + + server, errServer := s.Server(ctx, serverID) + if errServer != nil { + return errServer + } + + server.Deleted = true + + return s.repository.SaveServer(ctx, &server) } func NewServersUsecase(repository domain.ServersRepository) domain.ServersUsecase { - return &serversUsecase{servers: repository} + return &serversUsecase{repository: repository} } -func (s *serversUsecase) GetServer(ctx context.Context, serverID int) (domain.Server, error) { +func (s *serversUsecase) Server(ctx context.Context, serverID int) (domain.Server, error) { if serverID <= 0 { return domain.Server{}, domain.ErrGetServer } - return s.servers.GetServer(ctx, serverID) + return s.repository.GetServer(ctx, serverID) } -func (s *serversUsecase) GetServerPermissions(ctx context.Context) ([]domain.ServerPermission, error) { - return s.servers.GetServerPermissions(ctx) +func (s *serversUsecase) ServerPermissions(ctx context.Context) ([]domain.ServerPermission, error) { + return s.repository.GetServerPermissions(ctx) } -func (s *serversUsecase) GetServers(ctx context.Context, filter domain.ServerQueryFilter) ([]domain.Server, int64, error) { - return s.servers.GetServers(ctx, filter) +func (s *serversUsecase) Servers(ctx context.Context, filter domain.ServerQueryFilter) ([]domain.Server, int64, error) { + return s.repository.GetServers(ctx, filter) } -func (s *serversUsecase) GetServerByName(ctx context.Context, serverName string, server *domain.Server, disabledOk bool, deletedOk bool) error { - return s.servers.GetServerByName(ctx, serverName, server, disabledOk, deletedOk) +func (s *serversUsecase) GetByName(ctx context.Context, serverName string, server *domain.Server, disabledOk bool, deletedOk bool) error { + return s.repository.GetServerByName(ctx, serverName, server, disabledOk, deletedOk) } -func (s *serversUsecase) GetServerByPassword(ctx context.Context, serverPassword string, server *domain.Server, disabledOk bool, deletedOk bool) error { - return s.servers.GetServerByPassword(ctx, serverPassword, server, disabledOk, deletedOk) +func (s *serversUsecase) GetByPassword(ctx context.Context, serverPassword string, server *domain.Server, disabledOk bool, deletedOk bool) error { + return s.repository.GetServerByPassword(ctx, serverPassword, server, disabledOk, deletedOk) } -func (s *serversUsecase) SaveServer(ctx context.Context, server *domain.Server) error { - isNew := server.ServerID == 0 +func (s *serversUsecase) Save(ctx context.Context, req domain.RequestServerUpdate) (domain.Server, error) { + var server domain.Server - if err := s.servers.SaveServer(ctx, server); err != nil { - return err + if req.ServerID > 0 { + existingServer, errServer := s.Server(ctx, req.ServerID) + if errServer != nil { + return domain.Server{}, errServer + } + server = existingServer + server.UpdatedOn = time.Now() + } else { + server = domain.NewServer(req.ServerNameShort, req.Host, req.Port) } - if isNew { - slog.Info("Server config created", slog.Int("server_id", server.ServerID), slog.String("name", server.ShortName)) - } else { - slog.Info("Server config updated", slog.Int("server_id", server.ServerID), slog.String("name", server.ShortName), slog.Bool("deleted", server.Deleted)) + server.ShortName = req.ServerNameShort + server.Name = req.ServerName + server.Address = req.Host + server.Port = req.Port + server.ReservedSlots = req.ReservedSlots + server.RCON = req.RCON + server.Password = req.Password + server.Latitude = req.Lat + server.Longitude = req.Lon + server.CC = req.CC + server.Region = req.Region + server.IsEnabled = req.IsEnabled + server.LogSecret = req.LogSecret + server.EnableStats = req.EnableStats + + if err := s.repository.SaveServer(ctx, &server); err != nil { + return domain.Server{}, err } - return nil + return s.Server(ctx, server.ServerID) } diff --git a/internal/srcds/srcds_service.go b/internal/srcds/srcds_service.go index 7af4bfa6..9ce21367 100644 --- a/internal/srcds/srcds_service.go +++ b/internal/srcds/srcds_service.go @@ -1059,7 +1059,7 @@ func (s *srcdsHandler) onAPIPostPingMod() gin.HandlerFunc { return } - server, errServer := s.servers.GetServer(ctx, players[0].ServerID) + server, errServer := s.servers.Server(ctx, players[0].ServerID) if errServer != nil { slog.Error("Failed to load server", log.ErrAttr(errServer)) diff --git a/internal/state/collector.go b/internal/state/collector.go index fa7d8c3e..8ab71c8b 100644 --- a/internal/state/collector.go +++ b/internal/state/collector.go @@ -341,7 +341,7 @@ func (c *Collector) startStatus(ctx context.Context) { } func (c *Collector) updateServerConfigs(ctx context.Context) { - servers, _, errServers := c.serverUsecase.GetServers(ctx, domain.ServerQueryFilter{ + servers, _, errServers := c.serverUsecase.Servers(ctx, domain.ServerQueryFilter{ QueryFilter: domain.QueryFilter{Deleted: false}, IncludeDisabled: false, }) diff --git a/internal/state/state_usecase.go b/internal/state/state_usecase.go index 7df88b7b..0b206f1f 100644 --- a/internal/state/state_usecase.go +++ b/internal/state/state_usecase.go @@ -74,7 +74,7 @@ func (s *stateUsecase) updateSrcdsLogSecrets(ctx context.Context) { defer cancelServers() - servers, _, errServers := s.servers.GetServers(serversCtx, domain.ServerQueryFilter{ + servers, _, errServers := s.servers.Servers(serversCtx, domain.ServerQueryFilter{ IncludeDisabled: false, QueryFilter: domain.QueryFilter{Deleted: false}, }) diff --git a/internal/test/main_test.go b/internal/test/main_test.go index ab1667d7..d84405d5 100644 --- a/internal/test/main_test.go +++ b/internal/test/main_test.go @@ -85,8 +85,8 @@ func TestMain(m *testing.M) { } defer func() { - termCtx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() + termCtx, termCancel := context.WithTimeout(context.Background(), time.Second*30) + defer termCancel() if errTerm := container.Terminate(termCtx); errTerm != nil { panic(fmt.Sprintf("Failed to terminate test container: %v", errTerm)) @@ -155,6 +155,7 @@ func testRouter() *gin.Engine { ban.NewBanHandler(router, banSteamUC, discordUC, personUC, configUC, authUC) ban.NewBanNetHandler(router, banNetUC, authUC) ban.NewBanASNHandler(router, banASNUC, authUC) + servers.NewServerHandler(router, serversUC, stateUC, authUC, personUC) steamgroup.NewSteamgroupHandler(router, banGroupUC, authUC) news.NewNewsHandler(router, newsUC, discordUC, authUC) wiki.NewWIkiHandler(router, wikiUC, authUC) diff --git a/internal/test/servers_test.go b/internal/test/servers_test.go new file mode 100644 index 00000000..0d7e41dd --- /dev/null +++ b/internal/test/servers_test.go @@ -0,0 +1,134 @@ +package test_test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/leighmacdonald/gbans/internal/domain" + "github.com/leighmacdonald/gbans/pkg/stringutil" + "github.com/stretchr/testify/require" +) + +func TestServers(t *testing.T) { + router := testRouter() + owner := loginUser(getOwner()) + user := loginUser(getUser()) + + var servers []domain.Server + testEndpointWithReceiver(t, router, http.MethodGet, "/api/servers_admin", nil, http.StatusOK, owner, &servers) + require.Empty(t, servers) + + var safeServers []domain.ServerInfoSafe + testEndpointWithReceiver(t, router, http.MethodGet, "/api/servers", nil, http.StatusOK, user, &safeServers) + require.Empty(t, servers) + + newServer := domain.RequestServerUpdate{ + ServerName: "test-1 long", + ServerNameShort: "test-1", + Host: "1.2.3.4", + Port: 27015, + ReservedSlots: 8, + Password: stringutil.SecureRandomString(8), + RCON: stringutil.SecureRandomString(8), + Lat: 10, + Lon: 10, + CC: "us", + Region: "na", + IsEnabled: true, + EnableStats: false, + LogSecret: 12345678, + } + + var server domain.Server + testEndpointWithReceiver(t, router, http.MethodPost, "/api/servers", newServer, http.StatusOK, owner, &server) + + require.Equal(t, newServer.ServerNameShort, server.ShortName) + require.Equal(t, newServer.ServerName, server.Name) + require.Equal(t, newServer.Host, server.Address) + require.Equal(t, newServer.Port, server.Port) + require.Equal(t, newServer.ReservedSlots, server.ReservedSlots) + require.Equal(t, newServer.Password, server.Password) + require.Equal(t, newServer.RCON, server.RCON) + require.InEpsilon(t, newServer.Lat, server.Latitude, 0.001) + require.InEpsilon(t, newServer.Lon, server.Longitude, 0.001) + require.Equal(t, newServer.CC, server.CC) + require.Equal(t, newServer.Region, server.Region) + require.Equal(t, newServer.IsEnabled, server.IsEnabled) + require.Equal(t, newServer.EnableStats, server.EnableStats) + require.Equal(t, newServer.LogSecret, server.LogSecret) + + testEndpointWithReceiver(t, router, http.MethodGet, "/api/servers_admin", nil, http.StatusOK, owner, &servers) + require.NotEmpty(t, servers) + + testEndpointWithReceiver(t, router, http.MethodGet, "/api/servers", nil, http.StatusOK, user, &safeServers) + require.NotEmpty(t, servers) + + update := domain.RequestServerUpdate{ + ServerName: "test-2 long", + ServerNameShort: "test-2", + Host: "2.3.4.5", + Port: 27016, + ReservedSlots: 5, + Password: stringutil.SecureRandomString(8), + RCON: stringutil.SecureRandomString(8), + Lat: 11, + Lon: 11, + CC: "de", + Region: "eu", + IsEnabled: true, + EnableStats: true, + LogSecret: 23456789, + } + + var updated domain.Server + testEndpointWithReceiver(t, router, http.MethodPost, fmt.Sprintf("/api/servers/%d", server.ServerID), update, http.StatusOK, owner, &updated) + + require.Equal(t, update.ServerNameShort, updated.ShortName) + require.Equal(t, update.ServerName, updated.Name) + require.Equal(t, update.Host, updated.Address) + require.Equal(t, update.Port, updated.Port) + require.Equal(t, update.ReservedSlots, updated.ReservedSlots) + require.Equal(t, update.Password, updated.Password) + require.Equal(t, update.RCON, updated.RCON) + require.InEpsilon(t, update.Lat, updated.Latitude, 0.001) + require.InEpsilon(t, update.Lon, updated.Longitude, 0.001) + require.Equal(t, update.CC, updated.CC) + require.Equal(t, update.Region, updated.Region) + require.Equal(t, update.IsEnabled, updated.IsEnabled) + require.Equal(t, update.EnableStats, updated.EnableStats) + require.Equal(t, update.LogSecret, updated.LogSecret) + + testEndpoint(t, router, http.MethodDelete, fmt.Sprintf("/api/servers/%d", server.ServerID), nil, http.StatusOK, owner) + testEndpoint(t, router, http.MethodDelete, fmt.Sprintf("/api/servers/%d", server.ServerID), nil, http.StatusNotFound, owner) + testEndpoint(t, router, http.MethodDelete, "/api/servers/xx", nil, http.StatusBadRequest, owner) +} + +func TestServersPermissions(t *testing.T) { + testPermissions(t, testRouter(), []permTestValues{ + { + path: "/api/servers", + method: http.MethodPost, + code: http.StatusForbidden, + levels: admin, + }, + { + path: "/api/servers/1", + method: http.MethodPost, + code: http.StatusForbidden, + levels: admin, + }, + { + path: "/api/servers/:server_id", + method: http.MethodDelete, + code: http.StatusForbidden, + levels: admin, + }, + { + path: "/api/servers_admin", + method: http.MethodPost, + code: http.StatusForbidden, + levels: admin, + }, + }) +}