Skip to content

Commit

Permalink
Add register page when there are no admin available
Browse files Browse the repository at this point in the history
  • Loading branch information
RadhiFadlillah committed Mar 24, 2020
1 parent f8aeb42 commit c1f27ed
Show file tree
Hide file tree
Showing 20 changed files with 541 additions and 90 deletions.
66 changes: 55 additions & 11 deletions internal/backend/api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (h *Handler) SelectUsers(w http.ResponseWriter, r *http.Request, ps httprou
// Fetch from database
users := []model.User{}
err := h.db.Select(&users,
`SELECT id, username, name FROM user ORDER BY name`)
`SELECT id, username, name, admin FROM user ORDER BY name`)
checkError(err)

// Return list of users
Expand All @@ -31,9 +31,6 @@ func (h *Handler) SelectUsers(w http.ResponseWriter, r *http.Request, ps httprou

// InsertUser is handler for POST /api/user
func (h *Handler) InsertUser(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Make sure session still valid
h.auth.MustAuthenticateUser(r)

// Decode request
var user model.User
err := json.NewDecoder(r.Body).Decode(&user)
Expand All @@ -48,20 +45,53 @@ func (h *Handler) InsertUser(w http.ResponseWriter, r *http.Request, ps httprout
panic(fmt.Errorf("username must not empty"))
}

// Generate password
user.Password = randomString(10)
// Generate password if needed
if user.Password == "" {
user.Password = randomString(10)
}

// Start transaction
// Make sure to rollback if panic ever happened
tx := h.db.MustBegin()

defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()

// Prepare statements
stmtCountAdmin, err := tx.Preparex(`SELECT COUNT(id)
FROM user WHERE admin = 1`)
checkError(err)

stmtInsert, err := tx.Preparex(`INSERT INTO user
(username, name, password, admin) VALUES (?, ?, ?, ?)`)
checkError(err)

// If admin already exists, make sure session still valid
var nAdmin int
err = stmtCountAdmin.Get(&nAdmin)
checkError(err)

if nAdmin > 0 {
h.auth.MustAuthenticateUser(r)
}

// Hash password with bcrypt
password := []byte(user.Password)
hashedPassword, err := bcrypt.GenerateFromPassword(password, 10)
checkError(err)

// Insert user to database
res := h.db.MustExec(`INSERT INTO user
(username, name, password) VALUES (?, ?, ?)`,
user.Username, user.Name, hashedPassword)
res := stmtInsert.MustExec(user.Username, user.Name, hashedPassword, user.Admin)
user.ID, _ = res.LastInsertId()

// Commit transaction
err = tx.Commit()
checkError(err)

// Return inserted user
w.Header().Add("Content-Encoding", "gzip")
w.Header().Add("Content-Type", "application/json")
Expand Down Expand Up @@ -97,6 +127,9 @@ func (h *Handler) DeleteUsers(w http.ResponseWriter, r *http.Request, ps httprou
stmtDelete, err := tx.Preparex(`DELETE FROM user WHERE id = ?`)
checkError(err)

stmtCountAdmin, err := tx.Preparex(`SELECT COUNT(id) FROM user WHERE admin = 1`)
checkError(err)

// Delete from database
for _, id := range ids {
var username string
Expand All @@ -110,6 +143,15 @@ func (h *Handler) DeleteUsers(w http.ResponseWriter, r *http.Request, ps httprou
h.auth.MassLogout(username)
}

// Make sure at least one admin exists
var nAdmin int
err = stmtCountAdmin.Get(&nAdmin)
checkError(err)

if nAdmin == 0 {
panic(fmt.Errorf("at least one admin must exists"))
}

// Commit transaction
err = tx.Commit()
checkError(err)
Expand All @@ -135,8 +177,10 @@ func (h *Handler) UpdateUser(w http.ResponseWriter, r *http.Request, ps httprout
}

// Update user in database
h.db.MustExec(`UPDATE user SET username = ?, name = ? WHERE id = ?`,
user.Username, user.Name, user.ID)
h.db.MustExec(`UPDATE user
SET username = ?, name = ?, admin = ?
WHERE id = ?`,
user.Username, user.Name, user.Admin, user.ID)

// Return updated user
w.Header().Add("Content-Encoding", "gzip")
Expand Down
44 changes: 13 additions & 31 deletions internal/backend/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,50 +46,32 @@ func (auth *Authenticator) Login(username, password string) (string, model.User,
defer tx.Rollback()

// Prepare statements
stmtGetUserCount, err := tx.Preparex(`
SELECT COUNT(id) FROM user`)
if err != nil {
return "", emptyUser, fmt.Errorf("failed to prepare query: %w", err)
}

stmtGetUser, err := tx.Preparex(`
SELECT id, username, name, password
SELECT id, username, name, password, admin
FROM user WHERE username = ?`)
if err != nil {
return "", emptyUser, fmt.Errorf("failed to prepare query: %w", err)
}

// Get count of user
var nUser int
err = stmtGetUserCount.Get(&nUser)
if err != nil {
return "", emptyUser, fmt.Errorf("failed to get user count: %w", err)
}

// Get user from database
// Fetch user from database
var user model.User
if nUser > 0 {
err = stmtGetUser.Get(&user, username)
if err != nil && err != sql.ErrNoRows {
return "", emptyUser, fmt.Errorf("failed to get user: %w", err)
}
err = stmtGetUser.Get(&user, username)
if err != nil && err != sql.ErrNoRows {
return "", emptyUser, fmt.Errorf("failed to get user: %w", err)
}

if err == sql.ErrNoRows {
return "", emptyUser, fmt.Errorf("user doesn't exist")
}
if err == sql.ErrNoRows {
return "", emptyUser, fmt.Errorf("user doesn't exist")
}

err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil {
return "", emptyUser, fmt.Errorf("username and password don't match")
}
// Make sure its password matched.
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
if err != nil {
return "", emptyUser, fmt.Errorf("username and password don't match")
}

// Save user to session manager
expTime := time.Duration(0)
if user.ID == 0 {
expTime = 15 * time.Minute
}

session, err := auth.sessionManager.RegisterUser(user, expTime)
if err != nil {
return "", emptyUser, fmt.Errorf("failed to register user: %w", err)
Expand Down
3 changes: 2 additions & 1 deletion internal/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func ServeApp(db *sqlx.DB, port int) error {
return fmt.Errorf("failed to create authenticator: %w", err)
}

uiHdl, err := ui.NewHandler(auth)
uiHdl, err := ui.NewHandler(db, auth)
if err != nil {
return fmt.Errorf("failed to create UI handler: %w", err)
}
Expand All @@ -54,6 +54,7 @@ func ServeApp(db *sqlx.DB, port int) error {

router.GET("/", uiHdl.ServeIndex)
router.GET("/login", uiHdl.ServeLogin)
router.GET("/register", uiHdl.ServeRegister)
router.GET("/js/*filepath", uiHdl.ServeJsFile)
router.GET("/res/*filepath", uiHdl.ServeFile)
router.GET("/css/*filepath", uiHdl.ServeFile)
Expand Down
Loading

0 comments on commit c1f27ed

Please sign in to comment.