diff --git a/pkg/abstractions/endpoint/buffer.go b/pkg/abstractions/endpoint/buffer.go index 1f24ee6d1..ee339c2c0 100644 --- a/pkg/abstractions/endpoint/buffer.go +++ b/pkg/abstractions/endpoint/buffer.go @@ -434,6 +434,9 @@ func (rb *RequestBuffer) handleHttpRequest(req *request, c container) { } containerUrl := fmt.Sprintf("http://%s/%s", c.address, req.ctx.Param("subPath")) + if req.task.msg.NoOp { + containerUrl = fmt.Sprintf("http://%s/beta9/warmup", c.address) + } // Forward query params to the container if ASGI if rb.isASGI { diff --git a/pkg/abstractions/endpoint/endpoint.go b/pkg/abstractions/endpoint/endpoint.go index 185e2f22f..21e7e355e 100644 --- a/pkg/abstractions/endpoint/endpoint.go +++ b/pkg/abstractions/endpoint/endpoint.go @@ -164,6 +164,7 @@ func (es *HttpEndpointService) forwardRequest( ctx echo.Context, authInfo *auth.AuthInfo, stubId string, + noOp bool, ) error { instance, err := es.getOrCreateEndpointInstance(ctx.Request().Context(), stubId) if err != nil { @@ -204,6 +205,7 @@ func (es *HttpEndpointService) forwardRequest( MaxRetries: 0, Timeout: instance.StubConfig.TaskPolicy.Timeout, Expires: time.Now().Add(time.Duration(ttl) * time.Second), + NoOp: noOp, }) if err != nil { return err diff --git a/pkg/abstractions/endpoint/http.go b/pkg/abstractions/endpoint/http.go index 2b84578ab..47a971b17 100644 --- a/pkg/abstractions/endpoint/http.go +++ b/pkg/abstractions/endpoint/http.go @@ -29,6 +29,11 @@ func registerEndpointRoutes(g *echo.Group, es *HttpEndpointService) *endpointGro g.GET("/:deploymentName/v:version", auth.WithAuth(group.endpointRequest)) g.GET("/public/:stubId", auth.WithAssumedStubAuth(group.endpointRequest, group.es.isPublic)) + g.POST("/id/:stubId/warmup", auth.WithAuth(group.warmUpEndpoint)) + g.POST("/:deploymentName/warmup", auth.WithAuth(group.warmUpEndpoint)) + g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.warmUpEndpoint)) + g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.warmUpEndpoint)) + return group } @@ -88,7 +93,7 @@ func (g *endpointGroup) endpointRequest(ctx echo.Context) error { stubId = deployment.Stub.ExternalId } - return g.es.forwardRequest(ctx, cc.AuthInfo, stubId) + return g.es.forwardRequest(ctx, cc.AuthInfo, stubId, false) } func (g *endpointGroup) ASGIRequest(ctx echo.Context) error { @@ -130,5 +135,47 @@ func (g *endpointGroup) ASGIRequest(ctx echo.Context) error { stubId = deployment.Stub.ExternalId } - return g.es.forwardRequest(ctx, cc.AuthInfo, stubId) + return g.es.forwardRequest(ctx, cc.AuthInfo, stubId, false) +} + +func (g *endpointGroup) warmUpEndpoint(ctx echo.Context) error { + cc, _ := ctx.(*auth.HttpAuthContext) + + stubId := ctx.Param("stubId") + deploymentName := ctx.Param("deploymentName") + version := ctx.Param("version") + + if deploymentName != "" { + var deployment *types.DeploymentWithRelated + + if version == "" { + var err error + deployment, err = g.es.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeEndpointDeployment, true) + if err != nil { + return apiv1.HTTPBadRequest("Invalid deployment") + } + } else { + version, err := strconv.Atoi(version) + if err != nil { + return apiv1.HTTPBadRequest("Invalid deployment version") + } + + deployment, err = g.es.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeEndpointDeployment) + if err != nil { + return apiv1.HTTPBadRequest("Invalid deployment") + } + } + + if deployment == nil { + return apiv1.HTTPBadRequest("Invalid deployment") + } + + if !deployment.Active { + return apiv1.HTTPBadRequest("Deployment is not active") + } + + stubId = deployment.Stub.ExternalId + } + + return g.es.forwardRequest(ctx, cc.AuthInfo, stubId, true) } diff --git a/pkg/abstractions/taskqueue/http.go b/pkg/abstractions/taskqueue/http.go index 8a399c762..e69dea767 100644 --- a/pkg/abstractions/taskqueue/http.go +++ b/pkg/abstractions/taskqueue/http.go @@ -26,6 +26,11 @@ func registerTaskQueueRoutes(g *echo.Group, tq *RedisTaskQueue) *taskQueueGroup g.POST("/:deploymentName/v:version", auth.WithAuth(group.TaskQueuePut)) g.POST("/public/:stubId", auth.WithAssumedStubAuth(group.TaskQueuePut, group.tq.isPublic)) + g.POST("/id/:stubId/warmup", auth.WithAuth(group.TaskQueueWarmUp)) + g.POST("/:deploymentName/warmup", auth.WithAuth(group.TaskQueueWarmUp)) + g.POST("/:deploymentName/latest/warmup", auth.WithAuth(group.TaskQueueWarmUp)) + g.POST("/:deploymentName/v:version/warmup", auth.WithAuth(group.TaskQueueWarmUp)) + return group } @@ -75,7 +80,7 @@ func (g *taskQueueGroup) TaskQueuePut(ctx echo.Context) error { }) } - taskId, err := g.tq.put(ctx.Request().Context(), cc.AuthInfo, stubId, payload) + taskId, err := g.tq.put(ctx.Request().Context(), cc.AuthInfo, stubId, payload, false) if err != nil { if _, ok := err.(*types.ErrExceededTaskLimit); ok { return ctx.JSON(http.StatusTooManyRequests, map[string]interface{}{ @@ -92,3 +97,63 @@ func (g *taskQueueGroup) TaskQueuePut(ctx echo.Context) error { "task_id": taskId, }) } + +func (g *taskQueueGroup) TaskQueueWarmUp(ctx echo.Context) error { + cc, _ := ctx.(*auth.HttpAuthContext) + + stubId := ctx.Param("stubId") + deploymentName := ctx.Param("deploymentName") + version := ctx.Param("version") + + if deploymentName != "" { + var deployment *types.DeploymentWithRelated + + if version == "" { + var err error + deployment, err = g.tq.backendRepo.GetLatestDeploymentByName(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, types.StubTypeTaskQueueDeployment, true) + if err != nil { + return apiv1.HTTPBadRequest("Invalid deployment") + } + } else { + version, err := strconv.Atoi(version) + if err != nil { + return apiv1.HTTPBadRequest("Invalid deployment version") + } + + deployment, err = g.tq.backendRepo.GetDeploymentByNameAndVersion(ctx.Request().Context(), cc.AuthInfo.Workspace.Id, deploymentName, uint(version), types.StubTypeTaskQueueDeployment) + if err != nil { + return apiv1.HTTPBadRequest("Invalid deployment") + } + } + + if deployment == nil { + return apiv1.HTTPBadRequest("Invalid deployment") + } + + if !deployment.Active { + return apiv1.HTTPBadRequest("Deployment is not active") + } + + stubId = deployment.Stub.ExternalId + } + + taskId, err := g.tq.put( + ctx.Request().Context(), + cc.AuthInfo, + stubId, + &types.TaskPayload{ + Args: nil, + Kwargs: make(map[string]interface{}), + }, + true, + ) + if err != nil { + return ctx.JSON(http.StatusInternalServerError, map[string]string{ + "error": err.Error(), + }) + } + + return ctx.JSON(http.StatusOK, map[string]interface{}{ + "task_id": taskId, + }) +} diff --git a/pkg/abstractions/taskqueue/taskqueue.go b/pkg/abstractions/taskqueue/taskqueue.go index 66232e1a7..95e8db79e 100644 --- a/pkg/abstractions/taskqueue/taskqueue.go +++ b/pkg/abstractions/taskqueue/taskqueue.go @@ -184,7 +184,7 @@ func (tq *RedisTaskQueue) getStubConfig(stubId string) (*types.StubConfigV1, err return config, nil } -func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stubId string, payload *types.TaskPayload) (string, error) { +func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stubId string, payload *types.TaskPayload, noOp bool) (string, error) { stubConfig, err := tq.getStubConfig(stubId) if err != nil { return "", err @@ -206,6 +206,10 @@ func (tq *RedisTaskQueue) put(ctx context.Context, authInfo *auth.AuthInfo, stub } policy.Expires = time.Now().Add(time.Duration(policy.TTL) * time.Second) + if noOp { + policy.NoOp = true + } + task, err := tq.taskDispatcher.SendAndExecute(ctx, string(types.ExecutorTaskQueue), authInfo, stubId, payload, policy) if err != nil { return "", err @@ -226,7 +230,7 @@ func (tq *RedisTaskQueue) TaskQueuePut(ctx context.Context, in *pb.TaskQueuePutR }, nil } - taskId, err := tq.put(ctx, authInfo, in.StubId, &payload) + taskId, err := tq.put(ctx, authInfo, in.StubId, &payload, false) return &pb.TaskQueuePutResponse{ Ok: err == nil, TaskId: taskId, @@ -262,6 +266,21 @@ func (tq *RedisTaskQueue) TaskQueuePop(ctx context.Context, in *pb.TaskQueuePopR continue } + if tm.NoOp { + tq.TaskQueueComplete( + ctx, + &pb.TaskQueueCompleteRequest{ + TaskId: tm.TaskId, + StubId: tm.StubId, + ContainerId: in.ContainerId, + TaskStatus: string(types.TaskStatusComplete), + KeepWarmSeconds: float32(instance.StubConfig.KeepWarmSeconds), + TaskDuration: 0, + }, + ) + continue + } + t, err := tq.backendRepo.GetTaskWithRelated(ctx, tm.TaskId) if err != nil { continue diff --git a/pkg/task/dispatch.go b/pkg/task/dispatch.go index e64c91553..2045cabb7 100644 --- a/pkg/task/dispatch.go +++ b/pkg/task/dispatch.go @@ -81,6 +81,7 @@ func (d *Dispatcher) Send(ctx context.Context, executor string, authInfo *auth.A taskMessage.Kwargs = payload.Kwargs taskMessage.Policy = policy taskMessage.Timestamp = time.Now().Unix() + taskMessage.NoOp = policy.NoOp taskFactory, exists := d.executors.Get(executor) if !exists { diff --git a/pkg/types/task.go b/pkg/types/task.go index 9acf9bbd0..984adf865 100644 --- a/pkg/types/task.go +++ b/pkg/types/task.go @@ -61,6 +61,7 @@ type TaskMessage struct { Policy TaskPolicy `json:"policy" redis:"policy"` Retries uint `json:"retries" redis:"retries"` Timestamp int64 `json:"timestamp" redis:"timestamp"` + NoOp bool `json:"no_op" redis:"no_op"` } func (tm *TaskMessage) Reset() { @@ -119,6 +120,7 @@ type TaskPolicy struct { Timeout int `json:"timeout" redis:"timeout"` Expires time.Time `json:"expires" redis:"expires"` TTL uint32 `json:"ttl" redis:"ttl"` + NoOp bool `json:"no_op" redis:"no_op"` } type ErrExceededTaskLimit struct { diff --git a/sdk/src/beta9/runner/endpoint.py b/sdk/src/beta9/runner/endpoint.py index ab0e73ee2..7cce32ba2 100644 --- a/sdk/src/beta9/runner/endpoint.py +++ b/sdk/src/beta9/runner/endpoint.py @@ -212,6 +212,14 @@ async def function( return self._create_response(body=task_lifecycle_data.result, status_code=status_code) + @self.app.post("/beta9/warmup") + async def warmup( + request: Request, + ): + payload = await request.json() + task_id = payload.get("task_id") + return self._create_response(body=task_id, status_code=HTTPStatus.OK) + def _create_response(self, *, body: Any, status_code: int = HTTPStatus.OK) -> Response: if isinstance(body, Response): return body