diff --git a/pkg/abstractions/image/build.go b/pkg/abstractions/image/build.go index 9c45dbf71..0dc07a175 100644 --- a/pkg/abstractions/image/build.go +++ b/pkg/abstractions/image/build.go @@ -43,6 +43,7 @@ type Builder struct { registry *common.ImageRegistry containerRepo repository.ContainerRepository tailscale *network.Tailscale + rdb *common.RedisClient } type BuildStep struct { @@ -130,13 +131,14 @@ func (o *BuildOpts) addPythonRequirements() { o.PythonPackages = append(filteredPythonPackages, baseRequirementsSlice...) } -func NewBuilder(config types.AppConfig, registry *common.ImageRegistry, scheduler *scheduler.Scheduler, tailscale *network.Tailscale, containerRepo repository.ContainerRepository) (*Builder, error) { +func NewBuilder(config types.AppConfig, registry *common.ImageRegistry, scheduler *scheduler.Scheduler, tailscale *network.Tailscale, containerRepo repository.ContainerRepository, rdb *common.RedisClient) (*Builder, error) { return &Builder{ config: config, scheduler: scheduler, tailscale: tailscale, registry: registry, containerRepo: containerRepo, + rdb: rdb, }, nil } @@ -285,6 +287,14 @@ func (b *Builder) Build(ctx context.Context, opts *BuildOpts, outputChan chan co return err } + err = b.rdb.Set(ctx, Keys.imageBuildContainerTTL(containerId), "1", time.Duration(imageContainerTtlS)*time.Second).Err() + if err != nil { + outputChan <- common.OutputMsg{Done: true, Success: false, Msg: "Failed to connect to build container.\n"} + return err + } + + go b.keepAlive(ctx, containerId, ctx.Done()) + conn, err := network.ConnectToHost(ctx, hostname, time.Second*30, b.tailscale, b.config.Tailscale) if err != nil { outputChan <- common.OutputMsg{Done: true, Success: false, Msg: "Failed to connect to build container.\n"} @@ -439,6 +449,22 @@ func (b *Builder) Exists(ctx context.Context, imageId string) bool { return b.registry.Exists(ctx, imageId) } +func (b *Builder) keepAlive(ctx context.Context, containerId string, done <-chan struct{}) { + ticker := time.NewTicker(time.Duration(buildContainerKeepAliveIntervalS) * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-done: + return + case <-ticker.C: + b.rdb.Set(ctx, Keys.imageBuildContainerTTL(containerId), "1", time.Duration(imageContainerTtlS)*time.Second).Err() + } + } +} + var imageNamePattern = regexp.MustCompile( `^` + // Assert position at the start of the string `(?:(?P(?:(?:localhost|[\w.-]+(?:\.[\w.-]+)+)(?::\d+)?)|[\w]+:\d+)\/)?` + // Optional registry, which can be localhost, a domain with optional port, or a simple registry with port diff --git a/pkg/abstractions/image/image.go b/pkg/abstractions/image/image.go index 6bd587ac0..95d43e8ef 100644 --- a/pkg/abstractions/image/image.go +++ b/pkg/abstractions/image/image.go @@ -3,6 +3,7 @@ package image import ( "context" "fmt" + "strings" "github.com/beam-cloud/beta9/pkg/auth" "github.com/beam-cloud/beta9/pkg/common" @@ -23,9 +24,12 @@ type ImageService interface { type RuncImageService struct { pb.UnimplementedImageServiceServer - builder *Builder - config types.AppConfig - backendRepo repository.BackendRepository + builder *Builder + config types.AppConfig + backendRepo repository.BackendRepository + rdb *common.RedisClient + keyEventChan chan common.KeyEvent + keyEventManager *common.KeyEventManager } type ImageServiceOpts struct { @@ -34,8 +38,12 @@ type ImageServiceOpts struct { BackendRepo repository.BackendRepository Scheduler *scheduler.Scheduler Tailscale *network.Tailscale + RedisClient *common.RedisClient } +const buildContainerKeepAliveIntervalS int = 10 +const imageContainerTtlS int = 60 + func NewRuncImageService( ctx context.Context, opts ImageServiceOpts, @@ -45,16 +53,30 @@ func NewRuncImageService( return nil, err } - builder, err := NewBuilder(opts.Config, registry, opts.Scheduler, opts.Tailscale, opts.ContainerRepo) + builder, err := NewBuilder(opts.Config, registry, opts.Scheduler, opts.Tailscale, opts.ContainerRepo, opts.RedisClient) if err != nil { return nil, err } - return &RuncImageService{ - builder: builder, - config: opts.Config, - backendRepo: opts.BackendRepo, - }, nil + keyEventManager, err := common.NewKeyEventManager(opts.RedisClient) + if err != nil { + return nil, err + } + + is := RuncImageService{ + builder: builder, + config: opts.Config, + backendRepo: opts.BackendRepo, + keyEventChan: make(chan common.KeyEvent), + keyEventManager: keyEventManager, + rdb: opts.RedisClient, + } + + go is.monitorImageContainers(ctx) + go is.keyEventManager.ListenForPattern(ctx, Keys.imageBuildContainerTTL("*"), is.keyEventChan) + go is.keyEventManager.ListenForPattern(ctx, common.RedisKeys.SchedulerContainerState("*"), is.keyEventChan) + + return &is, nil } func (is *RuncImageService) VerifyImageBuild(ctx context.Context, in *pb.VerifyImageBuildRequest) (*pb.VerifyImageBuildResponse, error) { @@ -184,6 +206,35 @@ func (is *RuncImageService) retrieveBuildSecrets(ctx context.Context, secrets [] return buildSecrets, nil } +func (is *RuncImageService) monitorImageContainers(ctx context.Context) { + for { + select { + case event := <-is.keyEventChan: + switch event.Operation { + case common.KeyOperationSet: + if strings.Contains(event.Key, common.RedisKeys.SchedulerContainerState("")) { + containerId := strings.TrimPrefix(is.keyEventManager.TrimKeyspacePrefix(event.Key), common.RedisKeys.SchedulerContainerState("")) + + if is.rdb.Exists(ctx, Keys.imageBuildContainerTTL(containerId)).Val() == 0 { + is.builder.scheduler.Stop(&types.StopContainerArgs{ + ContainerId: containerId, + Force: true, + }) + } + } + case common.KeyOperationExpired: + containerId := strings.TrimPrefix(is.keyEventManager.TrimKeyspacePrefix(event.Key), Keys.imageBuildContainerTTL("")) + is.builder.scheduler.Stop(&types.StopContainerArgs{ + ContainerId: containerId, + Force: true, + }) + } + case <-ctx.Done(): + return + } + } +} + func convertBuildSteps(buildSteps []*pb.BuildStep) []BuildStep { steps := make([]BuildStep, len(buildSteps)) for i, s := range buildSteps { @@ -194,3 +245,15 @@ func convertBuildSteps(buildSteps []*pb.BuildStep) []BuildStep { } return steps } + +var ( + imageBuildContainerTTL string = "image:build_container_ttl:%s" +) + +var Keys = &keys{} + +type keys struct{} + +func (k *keys) imageBuildContainerTTL(containerId string) string { + return fmt.Sprintf(imageBuildContainerTTL, containerId) +} diff --git a/pkg/common/key_events.go b/pkg/common/key_events.go index 9102805a6..800f16688 100644 --- a/pkg/common/key_events.go +++ b/pkg/common/key_events.go @@ -33,6 +33,10 @@ func NewKeyEventManager(rdb *RedisClient) (*KeyEventManager, error) { return &KeyEventManager{rdb: rdb}, nil } +func (kem *KeyEventManager) TrimKeyspacePrefix(key string) string { + return strings.TrimPrefix(key, keyspacePrefix) +} + func (kem *KeyEventManager) fetchExistingKeys(patternPrefix string) ([]string, error) { pattern := fmt.Sprintf("%s*", patternPrefix) @@ -49,10 +53,6 @@ func (kem *KeyEventManager) fetchExistingKeys(patternPrefix string) ([]string, e return trimmedKeys, nil } -func (kem *KeyEventManager) TrimKeyspacePrefix(key string) string { - return strings.TrimPrefix(key, keyspacePrefix) -} - func (kem *KeyEventManager) ListenForPattern(ctx context.Context, patternPrefix string, keyEventChan chan KeyEvent) error { existingKeys, err := kem.fetchExistingKeys(patternPrefix) if err != nil { diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 8c5b7f7a2..14832da4f 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -251,6 +251,7 @@ func (g *Gateway) registerServices() error { Scheduler: g.Scheduler, Tailscale: g.Tailscale, BackendRepo: g.BackendRepo, + RedisClient: g.RedisClient, }) if err != nil { return err