diff --git a/backend/cmd/api/internal/controller/controller.go b/backend/cmd/api/internal/controller/controller.go index fe0ff4f..e0207c3 100644 --- a/backend/cmd/api/internal/controller/controller.go +++ b/backend/cmd/api/internal/controller/controller.go @@ -6,8 +6,7 @@ import ( "io" "net/http" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" + "github.com/CMS-Enterprise/ztmf/backend/cmd/api/internal/model" ) type response struct { @@ -34,33 +33,28 @@ func respond(w http.ResponseWriter, r *http.Request, data any, err error) { } if err == nil && data == nil { - err = &NotFoundError{} + err = ErrNotFound } if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - err = &NotFoundError{} - } - - if errors.Is(err, pgx.ErrTooManyRows) || errors.Is(err, pgx.ErrTxClosed) || errors.Is(err, pgx.ErrTxCommitRollback) { - err = &ServerError{} - } - - switch err.(type) { - case *pgconn.PgError: - status = 500 - err = &ServerError{} - case *ForbiddenError: + switch { + case errors.Is(err, model.ErrNoData): + status = 404 + err = ErrNotFound + case errors.Is(err, ErrForbidden): status = 403 - case *InvalidInputError: + case errors.Is(err, &model.InvalidInputError{}), errors.Is(err, model.ErrNotUnique): status = 400 - case *NotFoundError: - status = 404 - case *ServerError: - status = 500 + case errors.Is(err, model.ErrDbConnection): + err = ErrServiceUnavailable + status = 503 + case errors.Is(err, model.ErrTooMuchData): + fallthrough default: status = 500 + err = ErrServer } + res.Err = err.Error() } diff --git a/backend/cmd/api/internal/controller/errors.go b/backend/cmd/api/internal/controller/errors.go index b1b7321..cabb0b0 100644 --- a/backend/cmd/api/internal/controller/errors.go +++ b/backend/cmd/api/internal/controller/errors.go @@ -1,30 +1,12 @@ package controller -import "fmt" - -type ForbiddenError struct{} - -func (e *ForbiddenError) Error() string { - return "forbidden" -} - -type InvalidInputError struct { - field string - value any -} - -func (e *InvalidInputError) Error() string { - return fmt.Sprintf("invalid input for field `%s` with value `%s`", e.field, e.value) -} - -type NotFoundError struct{} - -func (e *NotFoundError) Error() string { - return "not found" -} - -type ServerError struct{} - -func (e *ServerError) Error() string { - return "server error" -} +import ( + "errors" +) + +var ( + ErrForbidden = errors.New("forbidden") + ErrNotFound = errors.New("not found") + ErrServer = errors.New("server error") + ErrServiceUnavailable = errors.New("service unavailable") +) diff --git a/backend/cmd/api/internal/controller/fismasystems.go b/backend/cmd/api/internal/controller/fismasystems.go index b9dc2e0..c036974 100644 --- a/backend/cmd/api/internal/controller/fismasystems.go +++ b/backend/cmd/api/internal/controller/fismasystems.go @@ -34,7 +34,7 @@ func GetFismaSystem(w http.ResponseWriter, r *http.Request) { } if !user.IsAdmin() && !user.IsAssignedFismaSystem(*input.FismaSystemID) { - respond(w, r, nil, &ForbiddenError{}) + respond(w, r, nil, ErrForbidden) return } diff --git a/backend/cmd/api/internal/controller/scores.go b/backend/cmd/api/internal/controller/scores.go index ec4972b..5c31c29 100644 --- a/backend/cmd/api/internal/controller/scores.go +++ b/backend/cmd/api/internal/controller/scores.go @@ -50,7 +50,7 @@ func SaveScore(w http.ResponseWriter, r *http.Request) { } if !user.IsAdmin() && !user.IsAssignedFismaSystem(input.FismaSystemID) { - respond(w, r, nil, &ForbiddenError{}) + respond(w, r, nil, ErrForbidden) return } diff --git a/backend/cmd/api/internal/controller/users.go b/backend/cmd/api/internal/controller/users.go index d9ace44..7c46df5 100644 --- a/backend/cmd/api/internal/controller/users.go +++ b/backend/cmd/api/internal/controller/users.go @@ -13,7 +13,7 @@ func ListUsers(w http.ResponseWriter, r *http.Request) { // TODO: replace the repititious admin checks with ACL authdUser := auth.UserFromContext(r.Context()) if !authdUser.IsAdmin() { - respond(w, r, nil, &ForbiddenError{}) + respond(w, r, nil, ErrForbidden) return } @@ -25,14 +25,14 @@ func ListUsers(w http.ResponseWriter, r *http.Request) { func GetUserById(w http.ResponseWriter, r *http.Request) { authdUser := auth.UserFromContext(r.Context()) if !authdUser.IsAdmin() { - respond(w, r, nil, &ForbiddenError{}) + respond(w, r, nil, ErrForbidden) return } vars := mux.Vars(r) ID, ok := vars["userid"] if !ok { - respond(w, r, nil, &InvalidInputError{"id", nil}) + respond(w, r, nil, ErrNotFound) return } @@ -51,7 +51,7 @@ func GetCurrentUser(w http.ResponseWriter, r *http.Request) { func SaveUser(w http.ResponseWriter, r *http.Request) { authdUser := auth.UserFromContext(r.Context()) if !authdUser.IsAdmin() { - respond(w, r, nil, &ForbiddenError{}) + respond(w, r, nil, ErrForbidden) return } @@ -64,18 +64,6 @@ func SaveUser(w http.ResponseWriter, r *http.Request) { return } - err = validateEmail(user.Email) - if err != nil { - respond(w, r, nil, err) - return - } - - err = validateRole(user.Role) - if err != nil { - respond(w, r, nil, err) - return - } - vars := mux.Vars(r) if v, ok := vars["userid"]; ok { user.UserID = v diff --git a/backend/cmd/api/internal/controller/validations.go b/backend/cmd/api/internal/controller/validations.go deleted file mode 100644 index ad920f3..0000000 --- a/backend/cmd/api/internal/controller/validations.go +++ /dev/null @@ -1,25 +0,0 @@ -package controller - -import "net/mail" - -// use of map enables O(1) vs O(N) as would be the case with slices.Contains([]string) -var roles = map[string]bool{ - "ISSO": true, // bool value isnt used, only the ok value is - "ISSM": true, - "ADMIN": true, -} - -func validateEmail(email string) error { - _, err := mail.ParseAddress(email) - if err != nil { - return &InvalidInputError{field: "email", value: email} - } - return nil -} - -func validateRole(role string) error { - if _, ok := roles[role]; !ok { - return &InvalidInputError{field: "role", value: role} - } - return nil -} diff --git a/backend/cmd/api/internal/model/datacalls.go b/backend/cmd/api/internal/model/datacalls.go index add5bcd..8b0a006 100644 --- a/backend/cmd/api/internal/model/datacalls.go +++ b/backend/cmd/api/internal/model/datacalls.go @@ -23,12 +23,12 @@ func FindDataCalls(ctx context.Context) ([]*DataCall, error) { if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*DataCall, error) { datacall := DataCall{} err := rows.Scan(&datacall.DataCallID, &datacall.DataCall, &datacall.DateCreated, &datacall.Deadline) - return &datacall, err + return &datacall, trapError(err) }) } diff --git a/backend/cmd/api/internal/model/errors.go b/backend/cmd/api/internal/model/errors.go index 30888e4..b8814bf 100644 --- a/backend/cmd/api/internal/model/errors.go +++ b/backend/cmd/api/internal/model/errors.go @@ -1,19 +1,58 @@ package model -import "fmt" +import ( + "errors" + "fmt" + "log" -type InvalidEmailError struct { - email string -} + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// provide generalized errors to decouple model consumers from db driver +var ( + ErrNoData = errors.New("no data when expected") + ErrTooMuchData = errors.New("more data than expected") + ErrDbConnection = errors.New("db connection error") + ErrNotUnique = errors.New("not unique") +) -func (e *InvalidEmailError) Error() string { - return fmt.Sprintf("invalid email: %s", e.email) +type InvalidInputError struct { + data map[string]string } -type InvalidRoleError struct { - role string +func (e *InvalidInputError) Error() string { + str := "invalid input:\n" + for k, v := range e.data { + str += " " + k + ":" + v + "\n" + } + return str } -func (e *InvalidRoleError) Error() string { - return fmt.Sprintf("invalid role: %s", e.role) +// trapError converts db driver errors into generic model errors +// this allows consumers of the model package to completely decouple from the driver +func trapError(e error) error { + if e == nil { + return nil + } + log.Print(e) + + // switch is the only way to check against custom error types + switch err := e.(type) { + case *pgconn.PgError: + switch err.Code { + case "23505": + return fmt.Errorf("%w : %s", ErrNotUnique, err.Detail) + } + } + + if errors.Is(e, pgx.ErrNoRows) { + return ErrNoData + } + + if errors.Is(e, pgx.ErrTooManyRows) { + return ErrTooMuchData + } + + return errors.New("unknown error") } diff --git a/backend/cmd/api/internal/model/fismasystems.go b/backend/cmd/api/internal/model/fismasystems.go index 208a197..75c476e 100644 --- a/backend/cmd/api/internal/model/fismasystems.go +++ b/backend/cmd/api/internal/model/fismasystems.go @@ -2,7 +2,6 @@ package model import ( "context" - "errors" "log" "github.com/jackc/pgx/v5" @@ -47,19 +46,21 @@ func FindFismaSystems(ctx context.Context, input FindFismaSystemsInput) ([]*Fism if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*FismaSystem, error) { fismaSystem := FismaSystem{} err := row.Scan(&fismaSystem.FismaSystemID, &fismaSystem.FismaUID, &fismaSystem.FismaAcronym, &fismaSystem.FismaName, &fismaSystem.FismaSubsystem, &fismaSystem.Component, &fismaSystem.Groupacronym, &fismaSystem.GroupName, &fismaSystem.DivisionName, &fismaSystem.DataCenterEnvironment, &fismaSystem.DataCallContact, &fismaSystem.ISSOEmail) - return &fismaSystem, err + return &fismaSystem, trapError(err) }) } func FindFismaSystem(ctx context.Context, input FindFismaSystemsInput) (*FismaSystem, error) { if input.FismaSystemID == nil { - return nil, errors.New("fismasystemid cannot be null") + return nil, &InvalidInputError{ + data: map[string]string{"fismasystemid": "null"}, + } } sqlb := sqlBuilder.Select("fismasystems.fismasystemid as fismasystemid, fismauid, fismaacronym, fismaname, fismasubsystem, component, groupacronym, groupname, divisionname, datacenterenvironment, datacallcontact, issoemail").From("fismasystems") @@ -69,14 +70,14 @@ func FindFismaSystem(ctx context.Context, input FindFismaSystemsInput) (*FismaSy row, err := queryRow(ctx, sql, boundArgs...) if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } fismaSystem := FismaSystem{} err = row.Scan(&fismaSystem.FismaSystemID, &fismaSystem.FismaUID, &fismaSystem.FismaAcronym, &fismaSystem.FismaName, &fismaSystem.FismaSubsystem, &fismaSystem.Component, &fismaSystem.Groupacronym, &fismaSystem.GroupName, &fismaSystem.DivisionName, &fismaSystem.DataCenterEnvironment, &fismaSystem.DataCallContact, &fismaSystem.ISSOEmail) if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return &fismaSystem, nil diff --git a/backend/cmd/api/internal/model/functionoptions.go b/backend/cmd/api/internal/model/functionoptions.go index 7898f07..a1a51f6 100644 --- a/backend/cmd/api/internal/model/functionoptions.go +++ b/backend/cmd/api/internal/model/functionoptions.go @@ -30,12 +30,12 @@ func FindFunctionOptions(ctx context.Context, input FindFunctionOptionsInput) ([ if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*FunctionOption, error) { fo := FunctionOption{} err := rows.Scan(&fo.FunctionOptionID, &fo.FunctionID, &fo.Score, &fo.OptionName, &fo.Description) - return &fo, err + return &fo, trapError(err) }) } diff --git a/backend/cmd/api/internal/model/model.go b/backend/cmd/api/internal/model/model.go index 9a69e72..29befd2 100644 --- a/backend/cmd/api/internal/model/model.go +++ b/backend/cmd/api/internal/model/model.go @@ -16,7 +16,7 @@ var sqlBuilder = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar) func query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { conn, err := db.Conn(ctx) if err != nil { - return nil, err + return nil, trapError(err) } return conn.Query(ctx, sql, args...) @@ -26,7 +26,7 @@ func query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { func queryRow(ctx context.Context, sql string, args ...any) (pgx.Row, error) { conn, err := db.Conn(ctx) if err != nil { - return nil, err + return nil, trapError(err) } row := conn.QueryRow(ctx, sql, args...) diff --git a/backend/cmd/api/internal/model/questions.go b/backend/cmd/api/internal/model/questions.go index 2c8008b..a780d86 100644 --- a/backend/cmd/api/internal/model/questions.go +++ b/backend/cmd/api/internal/model/questions.go @@ -33,7 +33,7 @@ func FindQuestions(ctx context.Context, input FindQuestionInput) ([]*Question, e if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*Question, error) { @@ -51,6 +51,6 @@ func FindQuestions(ctx context.Context, input FindQuestionInput) ([]*Question, e } err := rows.Scan(scanFields...) - return &question, err + return &question, trapError(err) }) } diff --git a/backend/cmd/api/internal/model/scores.go b/backend/cmd/api/internal/model/scores.go index cf6854c..83fc179 100644 --- a/backend/cmd/api/internal/model/scores.go +++ b/backend/cmd/api/internal/model/scores.go @@ -59,13 +59,13 @@ func FindScores(ctx context.Context, input FindScoresInput) ([]*Score, error) { if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*Score, error) { score := Score{} err := row.Scan(&score.ScoreID, &score.FismaSystemID, &score.DateCalculated, &score.Notes, &score.FunctionOptionID, &score.DataCallID) - return &score, err + return &score, trapError(err) }) } @@ -78,13 +78,13 @@ func CreateScore(ctx context.Context, input SaveScoreInput) (*Score, error) { sql, boundArgs, _ := sqlb.ToSql() row, err := queryRow(ctx, sql, boundArgs...) if err != nil { - return nil, err + return nil, trapError(err) } score := Score{} err = row.Scan(&score.ScoreID, &score.FismaSystemID, &score.DateCalculated, &score.Notes, &score.FunctionOptionID, &score.DataCallID) - return &score, err + return &score, trapError(err) } func UpdateScore(ctx context.Context, input SaveScoreInput) error { @@ -126,12 +126,12 @@ func FindScoresAggregate(ctx context.Context, input FindScoresInput) ([]*ScoreAg if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*ScoreAggregate, error) { sagg := ScoreAggregate{} err := row.Scan(&sagg.DataCallID, &sagg.FismaSystemID, &sagg.SystemScore) - return &sagg, err + return &sagg, trapError(err) }) } diff --git a/backend/cmd/api/internal/model/users.go b/backend/cmd/api/internal/model/users.go index 5542719..5c034a7 100644 --- a/backend/cmd/api/internal/model/users.go +++ b/backend/cmd/api/internal/model/users.go @@ -37,18 +37,21 @@ func FindUsers(ctx context.Context) ([]*User, error) { if err != nil { log.Println(err) - return nil, err + return nil, trapError(err) } return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*User, error) { user := User{} err := rows.Scan(&user.UserID, &user.Email, &user.FullName, &user.Role) - return &user, err + return &user, trapError(err) }) } // FindUserByID queries the database for a User with the given ID and returns *User or error func FindUserByID(ctx context.Context, userid string) (*User, error) { + if !isValidUUID(userid) { + return nil, ErrNoData + } return findUser(ctx, "users.userid=?", []any{userid}) } @@ -62,17 +65,21 @@ func findUser(ctx context.Context, where string, args []any) (*User, error) { sql, boundArgs, _ := sqlb.ToSql() row, err := queryRow(ctx, sql, boundArgs...) if err != nil { - return nil, err + return nil, trapError(err) } // Scan the query result into the User struct u := User{} err = row.Scan(&u.UserID, &u.Email, &u.FullName, &u.Role, &u.AssignedFismaSystems) - return &u, err + return &u, trapError(err) } func CreateUser(ctx context.Context, user User) (*User, error) { + if err := validateUser(user); err != nil { + return nil, err + } + sqlb := sqlBuilder.Insert("users"). Columns("email, fullname, role"). Values(user.Email, user.FullName, user.Role). @@ -81,15 +88,19 @@ func CreateUser(ctx context.Context, user User) (*User, error) { sql, boundArgs, _ := sqlb.ToSql() row, err := queryRow(ctx, sql, boundArgs...) if err != nil { - return nil, err + return nil, trapError(err) } err = row.Scan(&user.UserID) - return &user, err + return &user, trapError(err) } func UpdateUser(ctx context.Context, user User) (*User, error) { + if err := validateUser(user); err != nil { + return nil, err + } + sqlb := sqlBuilder.Update("users"). Set("email", user.Email). Set("fullname", user.FullName). @@ -100,12 +111,36 @@ func UpdateUser(ctx context.Context, user User) (*User, error) { sql, boundArgs, _ := sqlb.ToSql() row, err := queryRow(ctx, sql, boundArgs...) if err != nil { - return nil, err + return nil, trapError(err) } err = row.Scan(&user.UserID, &user.Email, &user.FullName, &user.Role) - return &user, err + return &user, trapError(err) +} + +func validateUser(user User) error { + err := InvalidInputError{data: map[string]string{}} + + if user.UserID != "" { + if !isValidUUID(user.UserID) { + err.data["userid"] = user.UserID + } + } + + if !isValidEmail(user.Email) { + err.data["email"] = user.Email + } + + if !isValidRole(user.Role) { + err.data["role"] = user.Role + } + + if len(err.data) > 0 { + return &err + } + + return nil } // func CreateUserFismaSystems(ctx context.Context, userid string, fismasystemids []int32) error { diff --git a/backend/cmd/api/internal/model/validations.go b/backend/cmd/api/internal/model/validations.go new file mode 100644 index 0000000..3e53e06 --- /dev/null +++ b/backend/cmd/api/internal/model/validations.go @@ -0,0 +1,29 @@ +package model + +import ( + "net/mail" + "regexp" +) + +// use of map enables O(1) vs O(N) as would be the case with slices.Contains([]string) +var roles = map[string]interface{}{ + "ISSO": nil, // the value isn't used, only the ok check value is + "ISSM": nil, + "ADMIN": nil, +} + +var rgxUUID = regexp.MustCompile("^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-4[a-fA-F0-9]{3}-[8|9|aA|bB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$") + +func isValidEmail(email string) bool { + _, err := mail.ParseAddress(email) + return err == nil +} + +func isValidRole(role string) bool { + _, ok := roles[role] + return ok +} + +func isValidUUID(uuid string) bool { + return rgxUUID.MatchString(uuid) +}