Skip to content

Commit

Permalink
Include origins in GET/POST pool endpoints and pools in GET/POST port…
Browse files Browse the repository at this point in the history
… endpoints (#89)

* centralize create assignment logic, include origins in get/post endpoints for pools

Signed-off-by: E Camden Fisher <efisher@equinix.com>

* fixup linter

Signed-off-by: E Camden Fisher <efisher@equinix.com>

* load assignments for ports

Signed-off-by: E Camden Fisher <efisher@equinix.com>

* cleanup assignments when deleting ports and pools

Signed-off-by: E Camden Fisher <efisher@equinix.com>

---------

Signed-off-by: E Camden Fisher <efisher@equinix.com>
  • Loading branch information
fishnix authored Mar 29, 2023
1 parent de2435d commit 7e940d9
Show file tree
Hide file tree
Showing 21 changed files with 896 additions and 238 deletions.
61 changes: 39 additions & 22 deletions pkg/api/v1/assignments_create.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package api

import (
"context"

"github.com/labstack/echo/v4"
"github.com/volatiletech/sqlboiler/v4/boil"
"github.com/volatiletech/sqlboiler/v4/queries/qm"
Expand Down Expand Up @@ -39,35 +41,18 @@ func (r *Router) assignmentsCreate(c echo.Context) error {
return v1BadRequestResponse(c, err)
}

// validate pool exists
pool, err := models.Pools(
models.PoolWhere.PoolID.EQ(payload.PoolID),
models.PoolWhere.TenantID.EQ(tenantID),
).One(ctx, r.db)
assignmentID, err := r.createAssignment(ctx, r.db, tenantID, port.LoadBalancerID, payload.PoolID, port.PortID)
if err != nil {
r.logger.Error("error fetching pool", zap.Error(err))
r.logger.Error("failed to create assignment", zap.Error(err))
return v1BadRequestResponse(c, err)
}

r.logger.Debug("validated pool exists", zap.Any("pool", pool))

assignment := models.Assignment{
TenantID: tenantID,
PortID: port.PortID,
PoolID: pool.PoolID,
}

if err := assignment.Insert(ctx, r.db, boil.Infer()); err != nil {
r.logger.Error("error inserting assignment", zap.Error(err))
return v1InternalServerErrorResponse(c, err)
}

msg, err := pubsub.NewAssignmentMessage(
someTestJWTURN,
pubsub.NewTenantURN(tenantID),
pubsub.NewAssignmentURN(assignment.AssignmentID),
pubsub.NewAssignmentURN(assignmentID),
pubsub.NewLoadBalancerURN(port.LoadBalancerID),
pubsub.NewPoolURN(pool.PoolID),
pubsub.NewPoolURN(payload.PoolID),
)
if err != nil {
// TODO: add status to reconcile and requeue this
Expand All @@ -79,5 +64,37 @@ func (r *Router) assignmentsCreate(c echo.Context) error {
r.logger.Error("error publishing assignment event", zap.Error(err))
}

return v1AssignmentsCreatedResponse(c, assignment.AssignmentID)
return v1AssignmentsCreatedResponse(c, assignmentID)
}

func (r *Router) createAssignment(ctx context.Context, exec boil.ContextExecutor, tenantID, loadBalancerID, poolID, portID string) (string, error) {
r.logger.Debug("creating assignment",
zap.String("tenant.id", tenantID),
zap.String("loadbalancer.id", loadBalancerID),
zap.String("pool.id", poolID),
zap.String("port.id", portID),
)

// validate pool exists
pool, err := models.Pools(
models.PoolWhere.PoolID.EQ(poolID),
models.PoolWhere.TenantID.EQ(tenantID),
).One(ctx, r.db)
if err != nil {
r.logger.Error("error fetching pool", zap.Error(err))
return "", err
}

assignment := models.Assignment{
TenantID: tenantID,
PortID: portID,
PoolID: pool.PoolID,
}

if err := assignment.Insert(ctx, exec, boil.Infer()); err != nil {
r.logger.Error("error inserting assignment", zap.Error(err))
return "", err
}

return assignment.AssignmentID, nil
}
2 changes: 1 addition & 1 deletion pkg/api/v1/assignments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func Test_Assignments(t *testing.T) {
method: http.MethodPost,
path: baseURL,
body: fmt.Sprintf(`{"port_id": "%s", "pool_id": "%s"}`, fe.ID, pool.ID),
status: http.StatusInternalServerError,
status: http.StatusBadRequest,
})

doHTTPTest(t, &httpTest{
Expand Down
39 changes: 3 additions & 36 deletions pkg/api/v1/load_balancers_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package api

import (
"context"
"database/sql"

"github.com/google/uuid"
"github.com/gosimple/slug"
Expand Down Expand Up @@ -104,7 +103,7 @@ func (r *Router) loadBalancerCreate(c echo.Context) error {
additionalURNs = append(additionalURNs, pubsub.NewPortURN(portID))

for _, pool := range p.Pools {
assignmentID, err := r.loadBalancerAssignmentCreate(ctx, tx, tenantID, lb.LoadBalancerID, pool, portID)
assignmentID, err := r.createAssignment(ctx, tx, tenantID, lb.LoadBalancerID, pool, portID)
if err != nil {
r.logger.Error("failed to create load balancer assignment, rolling back transaction", zap.Error(err))

Expand Down Expand Up @@ -146,7 +145,7 @@ func (r *Router) loadBalancerCreate(c echo.Context) error {
return v1LoadBalancerCreatedResponse(c, lb.LoadBalancerID)
}

func (r *Router) loadBalancerPortCreate(ctx context.Context, tx *sql.Tx, loadBalancerID string, portName string, portNumber int64) (string, error) {
func (r *Router) loadBalancerPortCreate(ctx context.Context, exec boil.ContextExecutor, loadBalancerID string, portName string, portNumber int64) (string, error) {
r.logger.Debug("creating loadbalancer port",
zap.String("loadbalancer.id", loadBalancerID),
zap.String("port.name", portName),
Expand All @@ -166,42 +165,10 @@ func (r *Router) loadBalancerPortCreate(ctx context.Context, tx *sql.Tx, loadBal
return "", err
}

if err := port.Insert(ctx, tx, boil.Infer()); err != nil {
if err := port.Insert(ctx, exec, boil.Infer()); err != nil {
r.logger.Error("failed to insert port", zap.Error(err))
return "", err
}

return port.PortID, nil
}

func (r *Router) loadBalancerAssignmentCreate(ctx context.Context, tx *sql.Tx, tenantID, loadBalancerID, poolID, portID string) (string, error) {
r.logger.Debug("creating loadbalancer assignment",
zap.String("tenant.id", tenantID),
zap.String("loadbalancer.id", loadBalancerID),
zap.String("pool.id", poolID),
zap.String("port.id", portID),
)

// validate pool exists
pool, err := models.Pools(
models.PoolWhere.PoolID.EQ(poolID),
models.PoolWhere.TenantID.EQ(tenantID),
).One(ctx, r.db)
if err != nil {
r.logger.Error("error fetching pool", zap.Error(err))
return "", err
}

assignment := models.Assignment{
TenantID: tenantID,
PortID: portID,
PoolID: pool.PoolID,
}

if err := assignment.Insert(ctx, tx, boil.Infer()); err != nil {
r.logger.Error("error inserting assignment", zap.Error(err))
return "", err
}

return assignment.AssignmentID, nil
}
14 changes: 12 additions & 2 deletions pkg/api/v1/load_balancers_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ func (r *Router) loadBalancerList(c echo.Context) error {
return v1BadRequestResponse(c, err)
}

lbs, err := models.LoadBalancers(append(mods, qm.Load("Ports"))...).All(ctx, r.db)
mods = append(mods,
qm.Load("Ports"),
qm.Load("Ports.Assignments"),
)

lbs, err := models.LoadBalancers(mods...).All(ctx, r.db)
if err != nil {
return v1InternalServerErrorResponse(c, err)
}
Expand All @@ -35,7 +40,12 @@ func (r *Router) loadBalancerGet(c echo.Context) error {
return v1BadRequestResponse(c, err)
}

lbs, err := models.LoadBalancers(append(mods, qm.Load("Ports"))...).All(ctx, r.db)
mods = append(mods,
qm.Load("Ports"),
qm.Load("Ports.Assignments"),
)

lbs, err := models.LoadBalancers(mods...).All(ctx, r.db)
if err != nil {
return v1InternalServerErrorResponse(c, err)
}
Expand Down
58 changes: 57 additions & 1 deletion pkg/api/v1/load_balancers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,20 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/google/uuid"
nats "github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"go.infratographer.com/load-balancer-api/internal/httptools"
"go.infratographer.com/load-balancer-api/internal/pubsub"
"go.infratographer.com/x/pubsubx"
)

const (
loadBalancerSubjectCreate = "com.infratographer.events.load-balancer.create.global"
loadBalancerSubjectDelete = "com.infratographer.events.load-balancer.delete.global"
loadBalancerBaseUrn = "urn:infratographer:load-balancer:"
)

func TestCreateLoadBalancer(t *testing.T) {
Expand All @@ -20,6 +30,26 @@ func TestCreateLoadBalancer(t *testing.T) {
srv := newTestServer(t, nsrv.ClientURL())
defer srv.Close()

// create a pubsub client for subscribing to NATS events
subscriber := newPubSubClient(t, nsrv.ClientURL())
msgChan := make(chan *nats.Msg, 10)

// create a new nats subscription on the server created above
subscription, err := subscriber.ChanSubscribe(
context.TODO(),
"com.infratographer.events.load-balancer.>",
msgChan,
"load-balancer-api-test",
)

assert.NoError(t, err)

defer func() {
if err := subscription.Unsubscribe(); err != nil {
t.Error(err)
}
}()

tenantID := uuid.New().String()
locationID := uuid.New().String()
ipID := uuid.New().String()
Expand Down Expand Up @@ -140,6 +170,19 @@ func TestCreateLoadBalancer(t *testing.T) {

assert.NoError(t, err)

select {
case msg := <-msgChan:
pMsg := &pubsubx.Message{}
err = json.Unmarshal(msg.Data, pMsg)
assert.NoError(t, err)

assert.Equal(t, loadBalancerSubjectCreate, msg.Subject)
assert.Equal(t, someTestJWTURN, pMsg.ActorURN)
assert.Equal(t, pubsub.CreateEventType, pMsg.EventType)
case <-time.After(natsMsgSubTimeout):
t.Error("failed to receive nats message for delete")
}

deleteRequest, err := http.NewRequestWithContext(
context.TODO(),
http.MethodDelete,
Expand All @@ -153,6 +196,19 @@ func TestCreateLoadBalancer(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, deleteResp.StatusCode)
defer deleteResp.Body.Close()

select {
case msg := <-msgChan:
pMsg := &pubsubx.Message{}
err = json.Unmarshal(msg.Data, pMsg)
assert.NoError(t, err)

assert.Equal(t, loadBalancerSubjectDelete, msg.Subject)
assert.Equal(t, someTestJWTURN, pMsg.ActorURN)
assert.Equal(t, pubsub.DeleteEventType, pMsg.EventType)
case <-time.After(natsMsgSubTimeout):
t.Error("failed to receive nats message for delete")
}
})
}
}
Expand Down Expand Up @@ -548,7 +604,7 @@ func createLoadBalancer(t *testing.T, srv *httptest.Server, locationID string) (
resp.Body.Close()
})

return (*loadbalancer.LoadBalancers)[0], func(t *testing.T) {
return (loadbalancer.LoadBalancers)[0], func(t *testing.T) {
test := &httpTest{
name: "delete nemo",
path: baseURL + "?slug=nemo",
Expand Down
61 changes: 40 additions & 21 deletions pkg/api/v1/origins_create.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package api

import (
"context"

"github.com/gosimple/slug"
"github.com/labstack/echo/v4"
"github.com/volatiletech/sqlboiler/v4/boil"
Expand All @@ -16,7 +18,7 @@ func (r *Router) originsCreate(c echo.Context) error {
Disabled bool `json:"disabled"`
Name string `json:"name"`
Target string `json:"target"`
Port int `json:"port"`
Port int64 `json:"port"`
}{}

if err := c.Bind(&payload); err != nil {
Expand All @@ -38,31 +40,16 @@ func (r *Router) originsCreate(c echo.Context) error {
return v1BadRequestResponse(c, err)
}

origin := models.Origin{
Name: payload.Name,
OriginUserSettingDisabled: payload.Disabled,
OriginTarget: payload.Target,
PoolID: pool.PoolID,
Port: int64(payload.Port),
Slug: slug.Make(payload.Name),
CurrentState: "configuring",
}

if err := validateOrigin(origin); err != nil {
r.logger.Error("error validating origins", zap.Error(err))
originID, err := r.createOrigin(ctx, r.db, pool.PoolID, payload.Name, payload.Target, payload.Port, payload.Disabled)
if err != nil {
r.logger.Error("failed to create origins", zap.Error(err))
return v1BadRequestResponse(c, err)
}

if err := origin.Insert(ctx, r.db, boil.Infer()); err != nil {
r.logger.Error("error inserting origins", zap.Error(err))

return v1InternalServerErrorResponse(c, err)
}

msg, err := pubsub.NewOriginMessage(
someTestJWTURN,
pubsub.NewTenantURN(pool.TenantID),
pubsub.NewOriginURN(origin.OriginID),
pubsub.NewOriginURN(originID),
pubsub.NewPoolURN(pool.PoolID),
)
if err != nil {
Expand All @@ -75,7 +62,7 @@ func (r *Router) originsCreate(c echo.Context) error {
r.logger.Error("error publishing origin event", zap.Error(err))
}

return v1OriginCreatedResponse(c, origin.OriginID)
return v1OriginCreatedResponse(c, originID)
}

func validateOrigin(o models.Origin) error {
Expand All @@ -89,3 +76,35 @@ func validateOrigin(o models.Origin) error {

return nil
}

func (r *Router) createOrigin(ctx context.Context, exec boil.ContextExecutor, poolID, name, target string, port int64, disabled bool) (string, error) {
r.logger.Debug("creating pool origin",
zap.String("pool.id", poolID),
zap.String("origin.name", name),
zap.String("origin.target", target),
zap.Int64("origin.port", port),
zap.Bool("origin.disabled", disabled),
)

origin := models.Origin{
Name: name,
OriginUserSettingDisabled: disabled,
OriginTarget: target,
PoolID: poolID,
Port: port,
Slug: slug.Make(name),
CurrentState: "configuring",
}

if err := validateOrigin(origin); err != nil {
r.logger.Error("error validating origins", zap.Error(err))
return "", err
}

if err := origin.Insert(ctx, exec, boil.Infer()); err != nil {
r.logger.Error("error inserting origins", zap.Error(err))
return "", err
}

return origin.OriginID, nil
}
2 changes: 1 addition & 1 deletion pkg/api/v1/origins_param_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (r *Router) originsParamsBinding(c echo.Context) ([]qm.QueryMod, error) {
r.logger.Debug("path param", zap.String("origin.id", originID))
}

queryParams := []string{"slug", "target", "port"}
queryParams := []string{"slug", "target", "port", "origin_id"}

qpb := echo.QueryParamsBinder(c)

Expand Down
Loading

0 comments on commit 7e940d9

Please sign in to comment.