Skip to content

Commit

Permalink
Feat: Add warm up endpoint to both endpoints and taskqueues
Browse files Browse the repository at this point in the history
  • Loading branch information
jsun-m committed Dec 26, 2024
1 parent b36b742 commit 32ec9b8
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 5 deletions.
3 changes: 3 additions & 0 deletions pkg/abstractions/endpoint/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ func (rb *RequestBuffer) handleHttpRequest(req *request, c container) {
}

containerUrl := fmt.Sprintf("http://%s/%s", c.address, req.ctx.Param("subPath"))
if req.task.msg.NoOp {
containerUrl = fmt.Sprintf("http://%s/beta9/warmup", c.address)
}

// Forward query params to the container if ASGI
if rb.isASGI {
Expand Down
2 changes: 2 additions & 0 deletions pkg/abstractions/endpoint/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func (es *HttpEndpointService) forwardRequest(
ctx echo.Context,
authInfo *auth.AuthInfo,
stubId string,
noOp bool,
) error {
instance, err := es.getOrCreateEndpointInstance(ctx.Request().Context(), stubId)
if err != nil {
Expand Down Expand Up @@ -204,6 +205,7 @@ func (es *HttpEndpointService) forwardRequest(
MaxRetries: 0,
Timeout: instance.StubConfig.TaskPolicy.Timeout,
Expires: time.Now().Add(time.Duration(ttl) * time.Second),
NoOp: noOp,
})
if err != nil {
return err
Expand Down
51 changes: 49 additions & 2 deletions pkg/abstractions/endpoint/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func registerEndpointRoutes(g *echo.Group, es *HttpEndpointService) *endpointGro
g.GET("/:deploymentName/v:version", auth.WithAuth(group.endpointRequest))
g.GET("/public/:stubId", auth.WithAssumedStubAuth(group.endpointRequest, group.es.isPublic))

g.POST("/id/:stubId/warmup", auth.WithAuth(group.warmUpEndpoint))
g.POST("/:deploymentName/warmup", auth.WithAuth(group.warmUpEndpoint))
g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.warmUpEndpoint))
g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.warmUpEndpoint))

return group
}

Expand Down Expand Up @@ -88,7 +93,7 @@ func (g *endpointGroup) endpointRequest(ctx echo.Context) error {
stubId = deployment.Stub.ExternalId
}

return g.es.forwardRequest(ctx, cc.AuthInfo, stubId)
return g.es.forwardRequest(ctx, cc.AuthInfo, stubId, false)
}

func (g *endpointGroup) ASGIRequest(ctx echo.Context) error {
Expand Down Expand Up @@ -130,5 +135,47 @@ func (g *endpointGroup) ASGIRequest(ctx echo.Context) error {
stubId = deployment.Stub.ExternalId
}

return g.es.forwardRequest(ctx, cc.AuthInfo, stubId)
return g.es.forwardRequest(ctx, cc.AuthInfo, stubId, false)
}

func (g *endpointGroup) warmUpEndpoint(ctx echo.Context) error {
cc, _ := ctx.(*auth.HttpAuthContext)

stubId := ctx.Param("stubId")
deploymentName := ctx.Param("deploymentName")
version := ctx.Param("version")

if deploymentName != "" {
var deployment *types.DeploymentWithRelated

if version == "" {
var err error
deployment, err = g.es.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeEndpointDeployment, true)
if err != nil {
return apiv1.HTTPBadRequest("Invalid deployment")
}
} else {
version, err := strconv.Atoi(version)
if err != nil {
return apiv1.HTTPBadRequest("Invalid deployment version")
}

deployment, err = g.es.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeEndpointDeployment)
if err != nil {
return apiv1.HTTPBadRequest("Invalid deployment")
}
}

if deployment == nil {
return apiv1.HTTPBadRequest("Invalid deployment")
}

if !deployment.Active {
return apiv1.HTTPBadRequest("Deployment is not active")
}

stubId = deployment.Stub.ExternalId
}

return g.es.forwardRequest(ctx, cc.AuthInfo, stubId, true)
}
73 changes: 72 additions & 1 deletion pkg/abstractions/taskqueue/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ func registerTaskQueueRoutes(g *echo.Group, tq *RedisTaskQueue) *taskQueueGroup
g.POST("/:deploymentName/v:version", auth.WithAuth(group.TaskQueuePut))
g.POST("/public/:stubId", auth.WithAssumedStubAuth(group.TaskQueuePut, group.tq.isPublic))

/*
g.POST("/id/:stubId/warmup", auth.WithAuth(group.warmUpEndpoint))
g.POST("/:deploymentName/warmup", auth.WithAuth(group.warmUpEndpoint))
g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.warmUpEndpoint))
g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.warmUpEndpoint))
*/
g.POST("/id/:stubId/warmup", auth.WithAuth(group.TaskQueueWarmUp))
g.POST("/:deploymentName/warmup", auth.WithAuth(group.TaskQueueWarmUp))
g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.TaskQueueWarmUp))
g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.TaskQueueWarmUp))

return group
}

Expand Down Expand Up @@ -75,7 +86,7 @@ func (g *taskQueueGroup) TaskQueuePut(ctx echo.Context) error {
})
}

taskId, err := g.tq.put(ctx.Request().Context(), cc.AuthInfo, stubId, payload)
taskId, err := g.tq.put(ctx.Request().Context(), cc.AuthInfo, stubId, payload, false)
if err != nil {
if _, ok := err.(*types.ErrExceededTaskLimit); ok {
return ctx.JSON(http.StatusTooManyRequests, map[string]interface{}{
Expand All @@ -92,3 +103,63 @@ func (g *taskQueueGroup) TaskQueuePut(ctx echo.Context) error {
"task_id": taskId,
})
}

func (g *taskQueueGroup) TaskQueueWarmUp(ctx echo.Context) error {
cc, _ := ctx.(*auth.HttpAuthContext)

stubId := ctx.Param("stubId")
deploymentName := ctx.Param("deploymentName")
version := ctx.Param("version")

if deploymentName != "" {
var deployment *types.DeploymentWithRelated

if version == "" {
var err error
deployment, err = g.tq.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeTaskQueueDeployment, true)
if err != nil {
return apiv1.HTTPBadRequest("Invalid deployment")
}
} else {
version, err := strconv.Atoi(version)
if err != nil {
return apiv1.HTTPBadRequest("Invalid deployment version")
}

deployment, err = g.tq.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeTaskQueueDeployment)
if err != nil {
return apiv1.HTTPBadRequest("Invalid deployment")
}
}

if deployment == nil {
return apiv1.HTTPBadRequest("Invalid deployment")
}

if !deployment.Active {
return apiv1.HTTPBadRequest("Deployment is not active")
}

stubId = deployment.Stub.ExternalId
}

taskId, err := g.tq.put(
ctx.Request().Context(),
cc.AuthInfo,
stubId,
&types.TaskPayload{
Args: nil,
Kwargs: make(map[string]interface{}),
},
true,
)
if err != nil {
return ctx.JSON(http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
}

return ctx.JSON(http.StatusOK, map[string]interface{}{
"task_id": taskId,
})
}
23 changes: 21 additions & 2 deletions pkg/abstractions/taskqueue/taskqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (tq *RedisTaskQueue) getStubConfig(stubId string) (*types.StubConfigV1, err
return config, nil
}

func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stubId string, payload *types.TaskPayload) (string, error) {
func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stubId string, payload *types.TaskPayload, noOp bool) (string, error) {
stubConfig, err := tq.getStubConfig(stubId)
if err != nil {
return "", err
Expand All @@ -206,6 +206,10 @@ func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stub
}
policy.Expires = time.Now().Add(time.Duration(policy.TTL) * time.Second)

if noOp {
policy.NoOp = true
}

task, err := tq.taskDispatcher.SendAndExecute(ctx, string(types.ExecutorTaskQueue), authInfo, stubId, payload, policy)
if err != nil {
return "", err
Expand All @@ -226,7 +230,7 @@ func (tq *RedisTaskQueue) TaskQueuePut(ctx context.Context, in *pb.TaskQueuePutR
}, nil
}

taskId, err := tq.put(ctx, authInfo, in.StubId, &payload)
taskId, err := tq.put(ctx, authInfo, in.StubId, &payload, false)
return &pb.TaskQueuePutResponse{
Ok: err == nil,
TaskId: taskId,
Expand Down Expand Up @@ -262,6 +266,21 @@ func (tq *RedisTaskQueue) TaskQueuePop(ctx context.Context, in *pb.TaskQueuePopR
continue
}

if tm.NoOp {
tq.TaskQueueComplete(
ctx,
&pb.TaskQueueCompleteRequest{
TaskId: tm.TaskId,
StubId: tm.StubId,
ContainerId: in.ContainerId,
TaskStatus: string(types.TaskStatusComplete),
KeepWarmSeconds: float32(instance.StubConfig.KeepWarmSeconds),
TaskDuration: 0,
},
)
continue
}

t, err := tq.backendRepo.GetTaskWithRelated(ctx, tm.TaskId)
if err != nil {
continue
Expand Down
1 change: 1 addition & 0 deletions pkg/task/dispatch.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (d *Dispatcher) Send(ctx context.Context, executor string, authInfo *auth.A
taskMessage.Kwargs = payload.Kwargs
taskMessage.Policy = policy
taskMessage.Timestamp = time.Now().Unix()
taskMessage.NoOp = policy.NoOp

taskFactory, exists := d.executors.Get(executor)
if !exists {
Expand Down
2 changes: 2 additions & 0 deletions pkg/types/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type TaskMessage struct {
Policy TaskPolicy `json:"policy" redis:"policy"`
Retries uint `json:"retries" redis:"retries"`
Timestamp int64 `json:"timestamp" redis:"timestamp"`
NoOp bool `json:"no_op" redis:"no_op"`
}

func (tm *TaskMessage) Reset() {
Expand Down Expand Up @@ -119,6 +120,7 @@ type TaskPolicy struct {
Timeout int `json:"timeout" redis:"timeout"`
Expires time.Time `json:"expires" redis:"expires"`
TTL uint32 `json:"ttl" redis:"ttl"`
NoOp bool `json:"no_op" redis:"no_op"`
}

type ErrExceededTaskLimit struct {
Expand Down
8 changes: 8 additions & 0 deletions sdk/src/beta9/runner/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ async def function(

return self._create_response(body=task_lifecycle_data.result, status_code=status_code)

@self.app.post("/beta9/warmup")
async def warmup(
request: Request,
):
payload = await request.json()
task_id = payload.get("task_id")
return self._create_response(body=task_id, status_code=HTTPStatus.OK)

def _create_response(self, *, body: Any, status_code: int = HTTPStatus.OK) -> Response:
if isinstance(body, Response):
return body
Expand Down

0 comments on commit 32ec9b8

Please sign in to comment.