Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model executor for s3/gcs/azure to duckdb #6353

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions runtime/drivers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ var spec = drivers.Spec{

type driver struct{}

type configProperties struct {
type ConfigProperties struct {
Account string `mapstructure:"azure_storage_account"`
Key string `mapstructure:"azure_storage_key"`
SASToken string `mapstructure:"azure_storage_sas_token"`
Expand All @@ -87,7 +87,7 @@ func (d driver) Open(instanceID string, config map[string]any, st *storage.Clien
return nil, errors.New("azure driver can't be shared")
}

conf := &configProperties{}
conf := &ConfigProperties{}
err := mapstructure.WeakDecode(config, conf)
if err != nil {
return nil, err
Expand All @@ -112,7 +112,7 @@ func (d driver) HasAnonymousSourceAccess(ctx context.Context, props map[string]a
}

conn := &Connection{
config: &configProperties{},
config: &ConfigProperties{},
logger: logger,
}

Expand All @@ -130,7 +130,7 @@ func (d driver) TertiarySourceConnectors(ctx context.Context, src map[string]any
}

type Connection struct {
config *configProperties
config *ConfigProperties
storage *storage.Client
logger *zap.Logger
}
Expand Down
3 changes: 3 additions & 0 deletions runtime/drivers/duckdb/duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ func (c *connection) AsModelExecutor(instanceID string, opts *drivers.ModelExecu
if opts.InputHandle.Driver() == "mysql" || opts.InputHandle.Driver() == "postgres" {
return &sqlStoreToSelfExecutor{c}, true
}
if _, ok := opts.InputHandle.AsObjectStore(); ok {
return &objectStoreToSelfExecutor{c}, true
}
}
if opts.InputHandle == c {
if opts.OutputHandle.Driver() == "file" {
Expand Down
167 changes: 167 additions & 0 deletions runtime/drivers/duckdb/model_executor_objectstore_self.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package duckdb

import (
"context"
"fmt"
"maps"
"strings"

"github.com/mitchellh/mapstructure"
"github.com/rilldata/rill/runtime/drivers"
"github.com/rilldata/rill/runtime/drivers/azure"
"github.com/rilldata/rill/runtime/drivers/gcs"
"github.com/rilldata/rill/runtime/drivers/s3"
"github.com/rilldata/rill/runtime/pkg/fileutil"
)

type s3InputProps struct {
Path string `mapstructure:"path"`
Format drivers.FileFormat `mapstructure:"format"`
DuckDB map[string]any `mapstructure:"duckdb"`
}

func (p *s3InputProps) Validate() error {
if p.Path == "" {
return fmt.Errorf("missing property `path`")
}
return nil
}

type objectStoreToSelfExecutor struct {
c *connection
}

var _ drivers.ModelExecutor = &objectStoreToSelfExecutor{}

func (e *objectStoreToSelfExecutor) Concurrency(desired int) (int, bool) {
if desired > 1 {
return 0, false
}
return 1, true
}

func (e *objectStoreToSelfExecutor) Execute(ctx context.Context, opts *drivers.ModelExecuteOptions) (*drivers.ModelResult, error) {
// Build the model executor options with updated input properties
clone := *opts
newInputProps, err := e.modelInputProperties(opts.ModelName, opts.InputConnector, opts.InputHandle, opts.InputProperties)
if err != nil {
return nil, err
}
clone.InputProperties = newInputProps
newOpts := &clone

// execute
executor := &selfToSelfExecutor{c: e.c}
return executor.Execute(ctx, newOpts)
}

func (e *objectStoreToSelfExecutor) modelInputProperties(model, inputConnector string, inputHandle drivers.Handle, inputProps map[string]any) (map[string]any, error) {
parsed := &s3InputProps{}
if err := mapstructure.WeakDecode(inputProps, parsed); err != nil {
return nil, fmt.Errorf("failed to parse input properties: %w", err)
}
if err := parsed.Validate(); err != nil {
return nil, fmt.Errorf("invalid input properties: %w", err)
}

m := &ModelInputProperties{}
var format string
if parsed.Format != "" {
format = fmt.Sprintf(".%s", parsed.Format)
} else {
format = fileutil.FullExt(parsed.Path)
}

config := inputHandle.Config()
// config properties can also be set as input properties
maps.Copy(config, inputProps)

// Generate secret SQL to access the service and set as pre_exec_query
safeSecretName := safeName(fmt.Sprintf("%s__%s__secret", model, inputConnector))
switch inputHandle.Driver() {
case "s3":
s3Config := &s3.ConfigProperties{}
err := mapstructure.WeakDecode(config, s3Config)
if err != nil {
return nil, fmt.Errorf("failed to parse s3 config properties: %w", err)
}
var sb strings.Builder
sb.WriteString("CREATE OR REPLACE TEMPORARY SECRET ")
sb.WriteString(safeSecretName)
sb.WriteString(" (TYPE S3")
if s3Config.AllowHostAccess {
sb.WriteString(", PROVIDER CREDENTIAL_CHAIN")
}
if s3Config.AccessKeyID != "" {
fmt.Fprintf(&sb, ", KEY_ID %s, SECRET %s", safeSQLString(s3Config.AccessKeyID), safeSQLString(s3Config.SecretAccessKey))
}
if s3Config.SessionToken != "" {
fmt.Fprintf(&sb, ", SESSION_TOKEN %s", safeSQLString(s3Config.SessionToken))
}
if s3Config.Endpoint != "" {
sb.WriteString(", ENDPOINT ")
sb.WriteString(safeSQLString(s3Config.Endpoint))
}
if s3Config.Region != "" {
sb.WriteString(", REGION ")
sb.WriteString(safeSQLString(s3Config.Region))
}
sb.WriteRune(')')
m.PreExec = sb.String()
case "gcs":
// GCS works via S3 compatibility mode
gcsConfig := &gcs.ConfigProperties{}
err := mapstructure.WeakDecode(config, gcsConfig)
if err != nil {
return nil, fmt.Errorf("failed to parse s3 config properties: %w", err)
}
var sb strings.Builder
sb.WriteString("CREATE OR REPLACE TEMPORARY SECRET ")
sb.WriteString(safeSecretName)
sb.WriteString(" (TYPE GCS")
if gcsConfig.AllowHostAccess {
sb.WriteString(", PROVIDER CREDENTIAL_CHAIN")
}
if gcsConfig.KeyID != "" {
fmt.Fprintf(&sb, ", KEY_ID %s, SECRET %s", safeSQLString(gcsConfig.KeyID), safeSQLString(gcsConfig.Secret))
}
sb.WriteRune(')')
m.PreExec = sb.String()
case "azure":
azureConfig := &azure.ConfigProperties{}
err := mapstructure.WeakDecode(config, azureConfig)
if err != nil {
return nil, fmt.Errorf("failed to parse s3 config properties: %w", err)
}
var sb strings.Builder
sb.WriteString("CREATE OR REPLACE TEMPORARY SECRET ")
sb.WriteString(safeSecretName)
sb.WriteString(" (TYPE AZURE")
if azureConfig.AllowHostAccess {
sb.WriteString(", PROVIDER CREDENTIAL_CHAIN")
}
if azureConfig.ConnectionString != "" {
fmt.Fprintf(&sb, ", CONNECTION_STRING %s", safeSQLString(azureConfig.ConnectionString))
}
if azureConfig.Account != "" {
fmt.Fprintf(&sb, ", ACCOUNT_NAME %s", safeSQLString(azureConfig.Account))
}
sb.WriteRune(')')
m.PreExec = sb.String()
default:
return nil, fmt.Errorf("internal error: unsupported object store: %s", inputHandle.Driver())
}

// Set SQL to read from the external source
from, err := sourceReader([]string{parsed.Path}, format, parsed.DuckDB)
if err != nil {
return nil, err
}
m.SQL = "SELECT * FROM " + from

propsMap := make(map[string]any)
if err := mapstructure.Decode(m, &propsMap); err != nil {
return nil, err
}
return propsMap, nil
}
9 changes: 6 additions & 3 deletions runtime/drivers/gcs/gcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,20 @@ var spec = drivers.Spec{

type driver struct{}

type configProperties struct {
type ConfigProperties struct {
SecretJSON string `mapstructure:"google_application_credentials"`
AllowHostAccess bool `mapstructure:"allow_host_access"`
// When working in s3 compatible mode
KeyID string `mapstructure:"key_id"`
Secret string `mapstructure:"secret"`
}

func (d driver) Open(instanceID string, config map[string]any, st *storage.Client, ac *activity.Client, logger *zap.Logger) (drivers.Handle, error) {
if instanceID == "" {
return nil, errors.New("gcs driver can't be shared")
}

conf := &configProperties{}
conf := &ConfigProperties{}
err := mapstructure.WeakDecode(config, conf)
if err != nil {
return nil, err
Expand Down Expand Up @@ -167,7 +170,7 @@ func parseSourceProperties(props map[string]any) (*sourceProperties, error) {
}

type Connection struct {
config *configProperties
config *ConfigProperties
storage *storage.Client
logger *zap.Logger
}
Expand Down
2 changes: 2 additions & 0 deletions runtime/drivers/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ type ConfigProperties struct {
AccessKeyID string `mapstructure:"aws_access_key_id"`
SecretAccessKey string `mapstructure:"aws_secret_access_key"`
SessionToken string `mapstructure:"aws_access_token"`
Endpoint string `mapstructure:"endpoint"`
Region string `mapstructure:"region"`
AllowHostAccess bool `mapstructure:"allow_host_access"`
RetainFiles bool `mapstructure:"retain_files"`
}
Expand Down
Loading