Skip to content

Commit

Permalink
refactor: decouple error handling between controller and db driver (#121
Browse files Browse the repository at this point in the history
)

closes #118
  • Loading branch information
talentedmrjones committed Sep 10, 2024
1 parent a032e7a commit c3c7094
Show file tree
Hide file tree
Showing 15 changed files with 173 additions and 130 deletions.
36 changes: 15 additions & 21 deletions backend/cmd/api/internal/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import (
"io"
"net/http"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/CMS-Enterprise/ztmf/backend/cmd/api/internal/model"
)

type response struct {
Expand All @@ -34,33 +33,28 @@ func respond(w http.ResponseWriter, r *http.Request, data any, err error) {
}

if err == nil && data == nil {
err = &NotFoundError{}
err = ErrNotFound
}

if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
err = &NotFoundError{}
}

if errors.Is(err, pgx.ErrTooManyRows) || errors.Is(err, pgx.ErrTxClosed) || errors.Is(err, pgx.ErrTxCommitRollback) {
err = &ServerError{}
}

switch err.(type) {
case *pgconn.PgError:
status = 500
err = &ServerError{}
case *ForbiddenError:
switch {
case errors.Is(err, model.ErrNoData):
status = 404
err = ErrNotFound
case errors.Is(err, ErrForbidden):
status = 403
case *InvalidInputError:
case errors.Is(err, &model.InvalidInputError{}), errors.Is(err, model.ErrNotUnique):
status = 400
case *NotFoundError:
status = 404
case *ServerError:
status = 500
case errors.Is(err, model.ErrDbConnection):
err = ErrServiceUnavailable
status = 503
case errors.Is(err, model.ErrTooMuchData):
fallthrough
default:
status = 500
err = ErrServer
}

res.Err = err.Error()
}

Expand Down
38 changes: 10 additions & 28 deletions backend/cmd/api/internal/controller/errors.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,12 @@
package controller

import "fmt"

type ForbiddenError struct{}

func (e *ForbiddenError) Error() string {
return "forbidden"
}

type InvalidInputError struct {
field string
value any
}

func (e *InvalidInputError) Error() string {
return fmt.Sprintf("invalid input for field `%s` with value `%s`", e.field, e.value)
}

type NotFoundError struct{}

func (e *NotFoundError) Error() string {
return "not found"
}

type ServerError struct{}

func (e *ServerError) Error() string {
return "server error"
}
import (
"errors"
)

var (
ErrForbidden = errors.New("forbidden")
ErrNotFound = errors.New("not found")
ErrServer = errors.New("server error")
ErrServiceUnavailable = errors.New("service unavailable")
)
2 changes: 1 addition & 1 deletion backend/cmd/api/internal/controller/fismasystems.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func GetFismaSystem(w http.ResponseWriter, r *http.Request) {
}

if !user.IsAdmin() && !user.IsAssignedFismaSystem(*input.FismaSystemID) {
respond(w, r, nil, &ForbiddenError{})
respond(w, r, nil, ErrForbidden)
return
}

Expand Down
2 changes: 1 addition & 1 deletion backend/cmd/api/internal/controller/scores.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func SaveScore(w http.ResponseWriter, r *http.Request) {
}

if !user.IsAdmin() && !user.IsAssignedFismaSystem(input.FismaSystemID) {
respond(w, r, nil, &ForbiddenError{})
respond(w, r, nil, ErrForbidden)
return
}

Expand Down
20 changes: 4 additions & 16 deletions backend/cmd/api/internal/controller/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func ListUsers(w http.ResponseWriter, r *http.Request) {
// TODO: replace the repititious admin checks with ACL
authdUser := auth.UserFromContext(r.Context())
if !authdUser.IsAdmin() {
respond(w, r, nil, &ForbiddenError{})
respond(w, r, nil, ErrForbidden)
return
}

Expand All @@ -25,14 +25,14 @@ func ListUsers(w http.ResponseWriter, r *http.Request) {
func GetUserById(w http.ResponseWriter, r *http.Request) {
authdUser := auth.UserFromContext(r.Context())
if !authdUser.IsAdmin() {
respond(w, r, nil, &ForbiddenError{})
respond(w, r, nil, ErrForbidden)
return
}

vars := mux.Vars(r)
ID, ok := vars["userid"]
if !ok {
respond(w, r, nil, &InvalidInputError{"id", nil})
respond(w, r, nil, ErrNotFound)
return
}

Expand All @@ -51,7 +51,7 @@ func GetCurrentUser(w http.ResponseWriter, r *http.Request) {
func SaveUser(w http.ResponseWriter, r *http.Request) {
authdUser := auth.UserFromContext(r.Context())
if !authdUser.IsAdmin() {
respond(w, r, nil, &ForbiddenError{})
respond(w, r, nil, ErrForbidden)
return
}

Expand All @@ -64,18 +64,6 @@ func SaveUser(w http.ResponseWriter, r *http.Request) {
return
}

err = validateEmail(user.Email)
if err != nil {
respond(w, r, nil, err)
return
}

err = validateRole(user.Role)
if err != nil {
respond(w, r, nil, err)
return
}

vars := mux.Vars(r)
if v, ok := vars["userid"]; ok {
user.UserID = v
Expand Down
25 changes: 0 additions & 25 deletions backend/cmd/api/internal/controller/validations.go

This file was deleted.

4 changes: 2 additions & 2 deletions backend/cmd/api/internal/model/datacalls.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ func FindDataCalls(ctx context.Context) ([]*DataCall, error) {

if err != nil {
log.Println(err)
return nil, err
return nil, trapError(err)
}

return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*DataCall, error) {
datacall := DataCall{}
err := rows.Scan(&datacall.DataCallID, &datacall.DataCall, &datacall.DateCreated, &datacall.Deadline)
return &datacall, err
return &datacall, trapError(err)
})
}
59 changes: 49 additions & 10 deletions backend/cmd/api/internal/model/errors.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,58 @@
package model

import "fmt"
import (
"errors"
"fmt"
"log"

type InvalidEmailError struct {
email string
}
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)

// provide generalized errors to decouple model consumers from db driver
var (
ErrNoData = errors.New("no data when expected")
ErrTooMuchData = errors.New("more data than expected")
ErrDbConnection = errors.New("db connection error")
ErrNotUnique = errors.New("not unique")
)

func (e *InvalidEmailError) Error() string {
return fmt.Sprintf("invalid email: %s", e.email)
type InvalidInputError struct {
data map[string]string
}

type InvalidRoleError struct {
role string
func (e *InvalidInputError) Error() string {
str := "invalid input:\n"
for k, v := range e.data {
str += " " + k + ":" + v + "\n"
}
return str
}

func (e *InvalidRoleError) Error() string {
return fmt.Sprintf("invalid role: %s", e.role)
// trapError converts db driver errors into generic model errors
// this allows consumers of the model package to completely decouple from the driver
func trapError(e error) error {
if e == nil {
return nil
}
log.Print(e)

// switch is the only way to check against custom error types
switch err := e.(type) {
case *pgconn.PgError:
switch err.Code {
case "23505":
return fmt.Errorf("%w : %s", ErrNotUnique, err.Detail)
}
}

if errors.Is(e, pgx.ErrNoRows) {
return ErrNoData
}

if errors.Is(e, pgx.ErrTooManyRows) {
return ErrTooMuchData
}

return errors.New("unknown error")
}
13 changes: 7 additions & 6 deletions backend/cmd/api/internal/model/fismasystems.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package model

import (
"context"
"errors"
"log"

"github.com/jackc/pgx/v5"
Expand Down Expand Up @@ -47,19 +46,21 @@ func FindFismaSystems(ctx context.Context, input FindFismaSystemsInput) ([]*Fism

if err != nil {
log.Println(err)
return nil, err
return nil, trapError(err)
}

return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*FismaSystem, error) {
fismaSystem := FismaSystem{}
err := row.Scan(&fismaSystem.FismaSystemID, &fismaSystem.FismaUID, &fismaSystem.FismaAcronym, &fismaSystem.FismaName, &fismaSystem.FismaSubsystem, &fismaSystem.Component, &fismaSystem.Groupacronym, &fismaSystem.GroupName, &fismaSystem.DivisionName, &fismaSystem.DataCenterEnvironment, &fismaSystem.DataCallContact, &fismaSystem.ISSOEmail)
return &fismaSystem, err
return &fismaSystem, trapError(err)
})
}

func FindFismaSystem(ctx context.Context, input FindFismaSystemsInput) (*FismaSystem, error) {
if input.FismaSystemID == nil {
return nil, errors.New("fismasystemid cannot be null")
return nil, &InvalidInputError{
data: map[string]string{"fismasystemid": "null"},
}
}

sqlb := sqlBuilder.Select("fismasystems.fismasystemid as fismasystemid, fismauid, fismaacronym, fismaname, fismasubsystem, component, groupacronym, groupname, divisionname, datacenterenvironment, datacallcontact, issoemail").From("fismasystems")
Expand All @@ -69,14 +70,14 @@ func FindFismaSystem(ctx context.Context, input FindFismaSystemsInput) (*FismaSy
row, err := queryRow(ctx, sql, boundArgs...)
if err != nil {
log.Println(err)
return nil, err
return nil, trapError(err)
}

fismaSystem := FismaSystem{}
err = row.Scan(&fismaSystem.FismaSystemID, &fismaSystem.FismaUID, &fismaSystem.FismaAcronym, &fismaSystem.FismaName, &fismaSystem.FismaSubsystem, &fismaSystem.Component, &fismaSystem.Groupacronym, &fismaSystem.GroupName, &fismaSystem.DivisionName, &fismaSystem.DataCenterEnvironment, &fismaSystem.DataCallContact, &fismaSystem.ISSOEmail)
if err != nil {
log.Println(err)
return nil, err
return nil, trapError(err)
}

return &fismaSystem, nil
Expand Down
4 changes: 2 additions & 2 deletions backend/cmd/api/internal/model/functionoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ func FindFunctionOptions(ctx context.Context, input FindFunctionOptionsInput) ([

if err != nil {
log.Println(err)
return nil, err
return nil, trapError(err)
}

return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*FunctionOption, error) {
fo := FunctionOption{}
err := rows.Scan(&fo.FunctionOptionID, &fo.FunctionID, &fo.Score, &fo.OptionName, &fo.Description)
return &fo, err
return &fo, trapError(err)
})
}
4 changes: 2 additions & 2 deletions backend/cmd/api/internal/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var sqlBuilder = squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar)
func query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
conn, err := db.Conn(ctx)
if err != nil {
return nil, err
return nil, trapError(err)
}

return conn.Query(ctx, sql, args...)
Expand All @@ -26,7 +26,7 @@ func query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) {
func queryRow(ctx context.Context, sql string, args ...any) (pgx.Row, error) {
conn, err := db.Conn(ctx)
if err != nil {
return nil, err
return nil, trapError(err)
}

row := conn.QueryRow(ctx, sql, args...)
Expand Down
4 changes: 2 additions & 2 deletions backend/cmd/api/internal/model/questions.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func FindQuestions(ctx context.Context, input FindQuestionInput) ([]*Question, e

if err != nil {
log.Println(err)
return nil, err
return nil, trapError(err)
}

return pgx.CollectRows(rows, func(row pgx.CollectableRow) (*Question, error) {
Expand All @@ -51,6 +51,6 @@ func FindQuestions(ctx context.Context, input FindQuestionInput) ([]*Question, e
}

err := rows.Scan(scanFields...)
return &question, err
return &question, trapError(err)
})
}
Loading

0 comments on commit c3c7094

Please sign in to comment.