Skip to content

Commit

Permalink
Refactor backend (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
glaslos authored Jul 16, 2023
1 parent ec822e0 commit c25c2cd
Show file tree
Hide file tree
Showing 43 changed files with 508 additions and 1,373 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
build:
npm run build
go build -o server main.go
go build server.go

local:
./server localhost:3000 token secret
File renamed without changes.
File renamed without changes.
File renamed without changes.
126 changes: 126 additions & 0 deletions backend/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package backend

import (
"context"
"encoding/json"
"io"
"net/http"
"os"

"github.com/honeynet/ochi/backend/entities"
"github.com/julienschmidt/httprouter"
"google.golang.org/api/idtoken"
)

func (cs *server) indexHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
fh, err := cs.fs.Open("index.html")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if _, err := io.Copy(w, fh); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}

func (cs *server) cssHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
fh, err := cs.fs.Open("global.css")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Add("Content-Type", "text/css")
if _, err := io.Copy(w, fh); err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}

// publishHandler reads the request body with a limit of 8192 bytes and then publishes
// the received message.
func (cs *server) publishHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
body := http.MaxBytesReader(w, r.Body, 8192)
msg, err := io.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
return
}

cs.publish(msg)

w.WriteHeader(http.StatusAccepted)
}

type response struct {
User entities.User `json:"user,omitempty"`
Token string `json:"token,omitempty"`
}

// sessionHandler creates a new token for the user
func (cs *server) sessionHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
userID := r.Context().Value(userID("userID")).(string)
user, err := cs.uRepo.Get(userID)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

token, err := entities.NewToken(os.Args[3], user)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

w.WriteHeader(http.StatusOK)
if err = json.NewEncoder(w).Encode(response{user, token}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}

// loginHandler validates a token with Google
func (cs *server) loginHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
body := http.MaxBytesReader(w, r.Body, 8192)
data, err := io.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
return
}

ctx := context.Background()
val, err := idtoken.NewValidator(ctx, idtoken.WithHTTPClient(cs.httpClient))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

payload, err := val.Validate(ctx, string(data), "610036027764-0lveoeejd62j594aqab5e24o2o82r8uf.apps.googleusercontent.com")
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

var user entities.User
if emailInt, ok := payload.Claims["email"]; ok {
if email, ok := emailInt.(string); ok {
user, err = cs.uRepo.Find(email)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
}
}

token, err := entities.NewToken(os.Args[3], user)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

w.WriteHeader(http.StatusOK)
if err = json.NewEncoder(w).Encode(response{user, token}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
50 changes: 50 additions & 0 deletions backend/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package backend

import (
"context"
"net/http"
"strings"

"github.com/honeynet/ochi/backend/entities"
"github.com/julienschmidt/httprouter"
)

func tokenMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
token, ok := r.URL.Query()["token"]
if !ok || len(token) == 0 || token[0] != secret {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}

h(w, r, ps)
}
}

type userID string

func bearerMiddleware(h httprouter.Handle, secret string) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
authHeader := r.Header.Get("Authentication")
authFields := strings.Fields(authHeader)
if len(authFields) != 2 || strings.ToLower(authFields[0]) != "bearer" {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
token := authFields[1]

claims, valid, err := entities.ValidateToken(token, secret)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if !valid {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

r = r.WithContext(context.WithValue(r.Context(), userID("userID"), claims.UserID))

h(w, r, ps)
}
}
2 changes: 1 addition & 1 deletion repos/user.go → backend/repos/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"database/sql"
"strings"

"github.com/honeynet/ochi/entities"
"github.com/honeynet/ochi/backend/entities"

"github.com/google/uuid"
"github.com/jmoiron/sqlx"
Expand Down
File renamed without changes.
33 changes: 33 additions & 0 deletions backend/routes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package backend

import (
"io/fs"
"net/http"
"os"

"github.com/julienschmidt/httprouter"
)

func newRouter(cs *server) (*httprouter.Router, error) {
r := httprouter.New()

// static
r.GET("/", cs.indexHandler)
r.GET("/global.css", cs.cssHandler)

build, err := fs.Sub(cs.fs, "build")
if err != nil {
return nil, err
}
r.ServeFiles("/build/*filepath", http.FS(build))

// websocket
r.GET("/subscribe", cs.subscribeHandler)
r.POST("/publish", tokenMiddleware(cs.publishHandler, os.Args[2]))

// user
r.POST("/login", cs.loginHandler)
r.GET("/session", bearerMiddleware(cs.sessionHandler, os.Args[3]))

return r, nil
}
133 changes: 133 additions & 0 deletions backend/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package backend

import (
"context"
"errors"
"io/fs"
"log"
"net"
"net/http"
"os"
"os/signal"
"sync"
"time"

"github.com/honeynet/ochi/backend/repos"
"nhooyr.io/websocket"

"github.com/jmoiron/sqlx"
"github.com/julienschmidt/httprouter"
"golang.org/x/time/rate"
)

type server struct {
// subscriberMessageBuffer controls the max number
// of messages that can be queued for a subscriber
// before it is kicked.
//
// Defaults to 16.
subscriberMessageBuffer int

// publishLimiter controls the rate limit applied to the publish endpoint.
//
// Defaults to one publish every 100ms with a burst of 8.
publishLimiter *rate.Limiter

// mux routes the various endpoints to the appropriate handler.
mux *httprouter.Router

subscribersMu sync.Mutex
subscribers map[*subscriber]struct{}

// the repositories
uRepo *repos.UserRepo

// http client
httpClient *http.Client

fs fs.FS
}

// NewServer constructs a server with the defaults.
func NewServer(fsys fs.FS) (*server, error) {
cs := &server{
subscriberMessageBuffer: 16,
subscribers: make(map[*subscriber]struct{}),
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
httpClient: &http.Client{
Timeout: time.Second,
Transport: &http.Transport{
TLSHandshakeTimeout: time.Second,
},
},
fs: fsys,
}

db, err := sqlx.Connect("sqlite3", "./data.db")
if err != nil {
log.Fatal(err)
}

cs.uRepo, err = repos.NewUserRepo(db)
if err != nil {
log.Fatal(err)
}

cs.mux, err = newRouter(cs)
if err != nil {
return nil, err
}

return cs, nil
}

func (cs *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cs.mux.ServeHTTP(w, r)
}

func writeTimeout(ctx context.Context, timeout time.Duration, c *websocket.Conn, msg []byte) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

return c.Write(ctx, websocket.MessageText, msg)
}

// run initializes the server
func (cs *server) Run() error {
if len(os.Args) < 4 {
return errors.New("please provide an address to listen on as the first argument, token second, secret third")
}

l, err := net.Listen("tcp", os.Args[1])
if err != nil {
return err
}
log.Printf("listening on http://%v", l.Addr())

srv := &http.Server{
Handler: cs,
ReadTimeout: time.Second * 10,
WriteTimeout: time.Second * 10,
}

defer cs.uRepo.Close()

errc := make(chan error, 1)
go func() {
errc <- srv.Serve(l)
}()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
select {
case err := <-errc:
log.Printf("failed to serve: %v", err)
case sig := <-sigs:
log.Printf("terminating: %v", sig)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()

return srv.Shutdown(ctx)
}
2 changes: 1 addition & 1 deletion main_test.go → backend/server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package main
package backend

import "testing"

Expand Down
Loading

0 comments on commit c25c2cd

Please sign in to comment.